diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3e6fc4626..728942b8f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -38,7 +38,6 @@ jobs: matrix: java: [ "11", "17" ] scala: [ "2.12.18", "2.13.12", "3.3.1" ] - elasticsearch: ["7.x", "8.x"] steps: - name: Checkout current branch uses: actions/checkout@v4.1.1 @@ -55,7 +54,7 @@ jobs: - name: Run tests run: ./sbt ++${{ matrix.scala }}! library/test - name: Run test container - run: docker-compose -f docker/elasticsearch-${{ matrix.elasticsearch }}.yml up -d + run: docker-compose -f docker/elasticsearch-8.x.yml up -d - name: Run integration tests run: ./sbt ++${{ matrix.scala }}! integration/test diff --git a/modules/integration/src/test/scala/zio/elasticsearch/HttpExecutorSpec.scala b/modules/integration/src/test/scala/zio/elasticsearch/HttpExecutorSpec.scala index 0e085e07e..dcaede455 100644 --- a/modules/integration/src/test/scala/zio/elasticsearch/HttpExecutorSpec.scala +++ b/modules/integration/src/test/scala/zio/elasticsearch/HttpExecutorSpec.scala @@ -1049,6 +1049,61 @@ object HttpExecutorSpec extends IntegrationSpec { Executor.execute(ElasticRequest.deleteIndex(firstSearchIndex)).orDie ) ), + suite("kNN search")( + test("search for top two results") { + checkOnce(genDocumentId, genTestDocument, genDocumentId, genTestDocument, genDocumentId, genTestDocument) { + (firstDocumentId, firstDocument, secondDocumentId, secondDocument, thirdDocumentId, thirdDocument) => + for { + _ <- Executor.execute(ElasticRequest.deleteByQuery(firstSearchIndex, matchAll)) + firstDocumentUpdated = firstDocument.copy(vectorField = List(1, 5, -20)) + secondDocumentUpdated = secondDocument.copy(vectorField = List(42, 8, -15)) + thirdDocumentUpdated = thirdDocument.copy(vectorField = List(15, 11, 23)) + req1 = ElasticRequest.create(firstSearchIndex, firstDocumentId, firstDocumentUpdated) + req2 = ElasticRequest.create(firstSearchIndex, secondDocumentId, secondDocumentUpdated) + req3 = ElasticRequest.create(firstSearchIndex, thirdDocumentId, thirdDocumentUpdated) + _ <- Executor.execute(ElasticRequest.bulk(req1, req2, req3).refreshTrue) + query = ElasticQuery.kNN(TestDocument.vectorField, 2, 3, Chunk(-5.0, 9.0, -12.0)) + res <- Executor.execute(ElasticRequest.knnSearch(firstSearchIndex, query)).documentAs[TestDocument] + } yield (assert(res)(equalTo(Chunk(firstDocumentUpdated, thirdDocumentUpdated)))) + } + } @@ around( + Executor.execute( + ElasticRequest.createIndex( + firstSearchIndex, + """{ "mappings": { "properties": { "vectorField": { "type": "dense_vector", "dims": 3, "similarity": "l2_norm", "index": true } } } }""" + ) + ), + Executor.execute(ElasticRequest.deleteIndex(firstSearchIndex)).orDie + ), + test("search for top two results with filters") { + checkOnce(genDocumentId, genTestDocument, genDocumentId, genTestDocument, genDocumentId, genTestDocument) { + (firstDocumentId, firstDocument, secondDocumentId, secondDocument, thirdDocumentId, thirdDocument) => + for { + _ <- Executor.execute(ElasticRequest.deleteByQuery(firstSearchIndex, matchAll)) + firstDocumentUpdated = firstDocument.copy(intField = 15, vectorField = List(1, 5, -20)) + secondDocumentUpdated = secondDocument.copy(intField = 21, vectorField = List(42, 8, -15)) + thirdDocumentUpdated = thirdDocument.copy(intField = 4, vectorField = List(15, 11, 23)) + req1 = ElasticRequest.create(firstSearchIndex, firstDocumentId, firstDocumentUpdated) + req2 = ElasticRequest.create(firstSearchIndex, secondDocumentId, secondDocumentUpdated) + req3 = ElasticRequest.create(firstSearchIndex, thirdDocumentId, thirdDocumentUpdated) + _ <- Executor.execute(ElasticRequest.bulk(req1, req2, req3).refreshTrue) + query = ElasticQuery.kNN(TestDocument.vectorField, 2, 3, Chunk(-5.0, 9.0, -12.0)) + filter = ElasticQuery.range(TestDocument.intField).gt(10) + res <- Executor + .execute(ElasticRequest.knnSearch(firstSearchIndex, query).filter(filter)) + .documentAs[TestDocument] + } yield (assert(res)(equalTo(Chunk(firstDocumentUpdated, secondDocumentUpdated)))) + } + } @@ around( + Executor.execute( + ElasticRequest.createIndex( + firstSearchIndex, + """{ "mappings": { "properties": { "vectorField": { "type": "dense_vector", "dims": 3, "similarity": "l2_norm", "index": true } } } }""" + ) + ), + Executor.execute(ElasticRequest.deleteIndex(firstSearchIndex)).orDie + ) + ) @@ shrinks(0), suite("searching for documents")( test("search for a document using a boosting query") { checkOnce(genDocumentId, genTestDocument, genDocumentId, genTestDocument) { diff --git a/modules/integration/src/test/scala/zio/elasticsearch/IntegrationSpec.scala b/modules/integration/src/test/scala/zio/elasticsearch/IntegrationSpec.scala index 1a1ef0dfa..892516fce 100644 --- a/modules/integration/src/test/scala/zio/elasticsearch/IntegrationSpec.scala +++ b/modules/integration/src/test/scala/zio/elasticsearch/IntegrationSpec.scala @@ -83,6 +83,7 @@ trait IntegrationSpec extends ZIOSpecDefault { doubleField <- Gen.double(100, 2000) booleanField <- Gen.boolean geoPointField <- genGeoPoint + vectorField <- Gen.listOfN(5)(Gen.int(-10, 10)) } yield TestDocument( stringField = stringField, dateField = dateField, @@ -90,7 +91,8 @@ trait IntegrationSpec extends ZIOSpecDefault { intField = intField, doubleField = doubleField, booleanField = booleanField, - geoPointField = geoPointField + geoPointField = geoPointField, + vectorField = vectorField ) def genTestSubDocument: Gen[Any, TestSubDocument] = for { diff --git a/modules/library/src/main/scala/zio/elasticsearch/ElasticQuery.scala b/modules/library/src/main/scala/zio/elasticsearch/ElasticQuery.scala index 13f2e5502..6582d89c7 100644 --- a/modules/library/src/main/scala/zio/elasticsearch/ElasticQuery.scala +++ b/modules/library/src/main/scala/zio/elasticsearch/ElasticQuery.scala @@ -546,6 +546,46 @@ object ElasticQuery { final def ids(value: String, values: String*): IdsQuery[Any] = Ids(values = Chunk.fromIterable(value +: values)) + /** + * Constructs a type-safe instance of [[zio.elasticsearch.query.KNNQuery]] using the specified parameters. + * [[zio.elasticsearch.query.KNNQuery]] is used to perform a k-nearest neighbor (kNN) search and returns the matching + * documents. + * + * @param field + * the type-safe field for which query is specified for + * @param k + * number of nearest neighbors to return as top hits (must be less than `numCandidates`) + * @param numCandidates + * number of nearest neighbor candidates to consider per shard + * @param queryVector + * query vector + * @tparam S + * document for which field query is executed + * @return + * an instance of [[zio.elasticsearch.query.KNNQuery]] that represents the kNN query to be performed. + */ + final def kNN[S](field: Field[S, _], k: Int, numCandidates: Int, queryVector: Chunk[Double]): KNNQuery[S] = + KNN(field = field.toString, k = k, numCandidates = numCandidates, queryVector = queryVector, similarity = None) + + /** + * Constructs an instance of [[zio.elasticsearch.query.KNNQuery]] using the specified parameters. + * [[zio.elasticsearch.query.KNNQuery]] is used to perform a k-nearest neighbor (kNN) search and returns the matching + * documents. + * + * @param field + * the field for which query is specified for + * @param k + * number of nearest neighbors to return as top hits (must be less than `numCandidates`) + * @param numCandidates + * number of nearest neighbor candidates to consider per shard + * @param queryVector + * query vector + * @return + * an instance of [[zio.elasticsearch.query.KNNQuery]] that represents the kNN query to be performed. + */ + final def kNN(field: String, k: Int, numCandidates: Int, queryVector: Chunk[Double]): KNNQuery[Any] = + KNN(field = field, k = k, numCandidates = numCandidates, queryVector = queryVector, similarity = None) + /** * Constructs an instance of [[zio.elasticsearch.query.MatchAllQuery]] used for matching all documents. * diff --git a/modules/library/src/main/scala/zio/elasticsearch/ElasticRequest.scala b/modules/library/src/main/scala/zio/elasticsearch/ElasticRequest.scala index fcfe0738b..4639f4af5 100644 --- a/modules/library/src/main/scala/zio/elasticsearch/ElasticRequest.scala +++ b/modules/library/src/main/scala/zio/elasticsearch/ElasticRequest.scala @@ -22,13 +22,14 @@ import zio.elasticsearch.IndexSelector.IndexNameSyntax import zio.elasticsearch.aggregation.ElasticAggregation import zio.elasticsearch.executor.response.BulkResponse import zio.elasticsearch.highlights.Highlights -import zio.elasticsearch.query.ElasticQuery import zio.elasticsearch.query.sort.Sort +import zio.elasticsearch.query.{ElasticQuery, KNNQuery} import zio.elasticsearch.request._ import zio.elasticsearch.request.options._ import zio.elasticsearch.result.{ AggregateResult, GetResult, + KNNSearchResult, SearchAndAggregateResult, SearchResult, UpdateByQueryResult @@ -215,6 +216,20 @@ object ElasticRequest { final def getById(index: IndexName, id: DocumentId): GetByIdRequest = GetById(index = index, id = id, refresh = None, routing = None) + /** + * Constructs an instance of [[KNNRequest]] used for performing a k-nearest neighbour (kNN) search. Given a query + * vector, it finds the k closest vectors and returns those documents as search hits. + * + * @param selectors + * the name of the index or more indices to search in + * @param query + * an instance of [[zio.elasticsearch.query.KNNQuery]] to run + * @return + * an instance of [[KNNRequest]] that represents k-nearest neighbour (kNN) operation to be performed. + */ + final def knnSearch[I: IndexSelector](selectors: I, query: KNNQuery[_]): KNNRequest = + KNN(knn = query, selectors = selectors.toSelector, filter = None, routing = None) + /** * Constructs an instance of [[RefreshRequest]] used for refreshing an index with the specified name. * @@ -593,6 +608,40 @@ object ElasticRequest { self.copy(routing = Some(value)) } + sealed trait KNNRequest extends ElasticRequest[KNNSearchResult] with HasRouting[KNNRequest] { + + /** + * Adds an [[zio.elasticsearch.ElasticQuery]] to the [[zio.elasticsearch.ElasticRequest.KNNRequest]] to filter the + * documents that can match. If not provided, all documents are allowed to match. + * + * @param query + * the Elastic query to be added + * @return + * an instance of a [[zio.elasticsearch.ElasticRequest.KNNRequest]] that represents the kNN search operation + * enriched with filter query to be performed. + */ + def filter(query: ElasticQuery[_]): KNNRequest + } + + private[elasticsearch] final case class KNN( + knn: KNNQuery[_], + selectors: String, + filter: Option[ElasticQuery[_]], + routing: Option[Routing] + ) extends KNNRequest { self => + + def filter(query: ElasticQuery[_]): KNNRequest = + self.copy(filter = Some(query)) + + def routing(value: Routing): KNNRequest = + self.copy(routing = Some(value)) + + private[elasticsearch] def toJson: Json = { + val filterJson: Json = filter.fold(Obj())(f => Obj("filter" -> f.toJson(None))) + Obj("knn" -> knn.toJson) merge filterJson + } + } + sealed trait RefreshRequest extends ElasticRequest[Boolean] private[elasticsearch] final case class Refresh(selectors: String) extends RefreshRequest @@ -612,7 +661,7 @@ object ElasticRequest { * [[zio.elasticsearch.ElasticRequest.SearchRequest]]. * * @param aggregation - * the elastic aggregation to be added + * the Elastic aggregation to be added * @return * an instance of a [[zio.elasticsearch.ElasticRequest.SearchAndAggregateRequest]] that represents search and * aggregate operations to be performed. diff --git a/modules/library/src/main/scala/zio/elasticsearch/executor/HttpExecutor.scala b/modules/library/src/main/scala/zio/elasticsearch/executor/HttpExecutor.scala index cd0246e07..902241733 100644 --- a/modules/library/src/main/scala/zio/elasticsearch/executor/HttpExecutor.scala +++ b/modules/library/src/main/scala/zio/elasticsearch/executor/HttpExecutor.scala @@ -79,6 +79,7 @@ private[elasticsearch] final class HttpExecutor private (esConfig: ElasticConfig case r: DeleteIndex => executeDeleteIndex(r) case r: Exists => executeExists(r) case r: GetById => executeGetById(r) + case r: KNN => executeKnn(r) case r: Refresh => executeRefresh(r) case r: Search => executeSearch(r) case r: SearchAndAggregate => executeSearchAndAggregate(r) @@ -372,6 +373,31 @@ private[elasticsearch] final class HttpExecutor private (esConfig: ElasticConfig } } + private def executeKnn(r: KNN): Task[KNNSearchResult] = { + val uri = uri"${esConfig.uri}/${r.selectors}/_knn_search".withParams( + getQueryParams(Chunk(("routing", r.routing))) + ) + + sendRequestWithCustomResponse[SearchWithAggregationsResponse]( + baseRequest + .post(uri) + .response(asJson[SearchWithAggregationsResponse]) + .contentType(ApplicationJson) + .body(r.toJson) + ).flatMap { response => + response.code match { + case HttpOk => + response.body.fold( + e => ZIO.fail(new ElasticException(s"Exception occurred: ${e.getMessage}")), + value => + ZIO.succeed(new KNNSearchResult(itemsFromDocumentsWithHighlights(value.resultsWithHighlightsAndSort))) + ) + case _ => + ZIO.fail(handleFailuresFromCustomResponse(response)) + } + } + } + private def executeRefresh(r: Refresh): Task[Boolean] = sendRequest(baseRequest.get(uri"${esConfig.uri}/${r.selectors}/$Refresh")).flatMap { response => response.code match { diff --git a/modules/library/src/main/scala/zio/elasticsearch/query/Queries.scala b/modules/library/src/main/scala/zio/elasticsearch/query/Queries.scala index 42aff4402..bf8f88f2e 100644 --- a/modules/library/src/main/scala/zio/elasticsearch/query/Queries.scala +++ b/modules/library/src/main/scala/zio/elasticsearch/query/Queries.scala @@ -266,6 +266,44 @@ private[elasticsearch] final case class Exists[S](field: String, boost: Option[D ) } +sealed trait KNNQuery[-S] { self => + + /** + * Sets the `similarity` parameter for the [[zio.elasticsearch.query.KNNQuery]]. The `similarity` parameter is the + * required minimum similarity for a vector to be considered a match. + * + * @param value + * a non-negative real number used for the `similarity` + * @return + * an instance of [[zio.elasticsearch.query.KNNQuery]] enriched with the `similarity` parameter. + */ + def similarity(value: Double): KNNQuery[S] + + private[elasticsearch] def toJson: Json +} + +private[elasticsearch] final case class KNN[S]( + field: String, + k: Int, + numCandidates: Int, + queryVector: Chunk[Double], + similarity: Option[Double] +) extends KNNQuery[S] { self => + + def similarity(value: Double): KNN[S] = + self.copy(similarity = Some(value)) + + private[elasticsearch] def toJson: Json = { + val similarityJson = similarity.fold(Obj())(s => Obj("similarity" -> s.toJson)) + Obj( + "field" -> field.toJson, + "query_vector" -> Arr(queryVector.map(_.toJson)), + "k" -> k.toJson, + "num_candidates" -> numCandidates.toJson + ) merge similarityJson + } +} + sealed trait FunctionScoreQuery[S] extends ElasticQuery[S] with HasBoost[FunctionScoreQuery[S]] { /** diff --git a/modules/library/src/main/scala/zio/elasticsearch/result/ElasticResult.scala b/modules/library/src/main/scala/zio/elasticsearch/result/ElasticResult.scala index 0e2f91ba0..e381d4257 100644 --- a/modules/library/src/main/scala/zio/elasticsearch/result/ElasticResult.scala +++ b/modules/library/src/main/scala/zio/elasticsearch/result/ElasticResult.scala @@ -106,6 +106,18 @@ final class GetResult private[elasticsearch] (private val doc: Option[Item]) ext }) } +final class KNNSearchResult private[elasticsearch] (private val hits: Chunk[Item]) extends DocumentResult[Chunk] { + + def documentAs[A: Schema]: IO[DecodingException, Chunk[A]] = + ZIO.fromEither { + ZValidation.validateAll(hits.map(item => ZValidation.fromEither(item.documentAs))).toEitherWith { errors => + DecodingException(s"Could not parse all documents successfully: ${errors.map(_.message).mkString(", ")}") + } + } + + lazy val items: UIO[Chunk[Item]] = ZIO.succeed(hits) +} + final class SearchResult private[elasticsearch] ( private val hits: Chunk[Item], private val fullResponse: SearchWithAggregationsResponse diff --git a/modules/library/src/test/scala/zio/elasticsearch/ElasticQuerySpec.scala b/modules/library/src/test/scala/zio/elasticsearch/ElasticQuerySpec.scala index 698eb0b29..47d7484c4 100644 --- a/modules/library/src/test/scala/zio/elasticsearch/ElasticQuerySpec.scala +++ b/modules/library/src/test/scala/zio/elasticsearch/ElasticQuerySpec.scala @@ -1074,6 +1074,93 @@ object ElasticQuerySpec extends ZIOSpecDefault { ) ) }, + test("kNN") { + val queryString = kNN("stringField", 5, 10, Chunk(1.1, 2.2, 3.3)) + val queryBool = kNN("boolField", 5, 10, Chunk(1.1, 2.2, 3.3)) + val queryInt = kNN("intField", 5, 10, Chunk(1.1, 2.2, 3.3)) + val queryStringTs = kNN(TestDocument.stringField, 5, 10, Chunk(1.1, 2.2, 3.3)) + val queryBoolTs = kNN(TestDocument.booleanField, 5, 10, Chunk(1.1, 2.2, 3.3)) + val queryIntTs = kNN(TestDocument.intField, 5, 10, Chunk(1.1, 2.2, 3.3)) + val queryWithSimilarity = kNN(TestDocument.stringField, 5, 10, Chunk(1.1, 2.2, 3.3)).similarity(3.14) + + assert(queryString)( + equalTo( + KNN[Any]( + field = "stringField", + k = 5, + numCandidates = 10, + queryVector = Chunk(1.1, 2.2, 3.3), + similarity = None + ) + ) + ) && + assert(queryBool)( + equalTo( + KNN[Any]( + field = "boolField", + k = 5, + numCandidates = 10, + queryVector = Chunk(1.1, 2.2, 3.3), + similarity = None + ) + ) + ) && + assert(queryInt)( + equalTo( + KNN[Any]( + field = "intField", + k = 5, + numCandidates = 10, + queryVector = Chunk(1.1, 2.2, 3.3), + similarity = None + ) + ) + ) && + assert(queryStringTs)( + equalTo( + KNN[TestDocument]( + field = "stringField", + k = 5, + numCandidates = 10, + queryVector = Chunk(1.1, 2.2, 3.3), + similarity = None + ) + ) + ) && + assert(queryBoolTs)( + equalTo( + KNN[TestDocument]( + field = "booleanField", + k = 5, + numCandidates = 10, + queryVector = Chunk(1.1, 2.2, 3.3), + similarity = None + ) + ) + ) && + assert(queryIntTs)( + equalTo( + KNN[TestDocument]( + field = "intField", + k = 5, + numCandidates = 10, + queryVector = Chunk(1.1, 2.2, 3.3), + similarity = None + ) + ) + ) && + assert(queryWithSimilarity)( + equalTo( + KNN[TestDocument]( + field = "stringField", + k = 5, + numCandidates = 10, + queryVector = Chunk(1.1, 2.2, 3.3), + similarity = Some(3.14) + ) + ) + ) + }, test("matchAll") { val query = matchAll val queryWithBoost = matchAll.boost(3.14) @@ -3328,6 +3415,58 @@ object ElasticQuerySpec extends ZIOSpecDefault { assert(query.toJson(fieldPath = None))(equalTo(expected.toJson)) }, + test("kNN") { + val queryString = kNN(TestDocument.stringField, 5, 10, Chunk(1.1, 2.2, 3.3)) + val queryBool = kNN(TestDocument.booleanField, 5, 10, Chunk(1.1, 2.2, 3.3)) + val queryInt = kNN(TestDocument.intField, 5, 10, Chunk(1.1, 2.2, 3.3)) + val queryWithSimilarity = kNN(TestDocument.stringField, 5, 10, Chunk(1.1, 2.2, 3.3)).similarity(3.14) + + val expectedString = + """ + |{ + | "field": "stringField", + | "query_vector": [1.1, 2.2, 3.3], + | "k": 5, + | "num_candidates": 10 + |} + |""".stripMargin + + val expectedBool = + """ + |{ + | "field": "booleanField", + | "query_vector": [1.1, 2.2, 3.3], + | "k": 5, + | "num_candidates": 10 + |} + |""".stripMargin + + val expectedInt = + """ + |{ + | "field": "intField", + | "query_vector": [1.1, 2.2, 3.3], + | "k": 5, + | "num_candidates": 10 + |} + |""".stripMargin + + val expectedWithSimilarity = + """ + |{ + | "field": "stringField", + | "query_vector": [1.1, 2.2, 3.3], + | "k": 5, + | "num_candidates": 10, + | "similarity": 3.14 + |} + |""".stripMargin + + assert(queryString.toJson)(equalTo(expectedString.toJson)) && + assert(queryBool.toJson)(equalTo(expectedBool.toJson)) && + assert(queryInt.toJson)(equalTo(expectedInt.toJson)) && + assert(queryWithSimilarity.toJson)(equalTo(expectedWithSimilarity.toJson)) + }, test("matchAll") { val query = matchAll val queryWithBoost = matchAll.boost(3.14) diff --git a/modules/library/src/test/scala/zio/elasticsearch/ElasticRequestSpec.scala b/modules/library/src/test/scala/zio/elasticsearch/ElasticRequestSpec.scala index ce382d449..699df774a 100644 --- a/modules/library/src/test/scala/zio/elasticsearch/ElasticRequestSpec.scala +++ b/modules/library/src/test/scala/zio/elasticsearch/ElasticRequestSpec.scala @@ -261,6 +261,36 @@ object ElasticRequestSpec extends ZIOSpecDefault { equalTo(GetById(index = Index, id = DocId, refresh = Some(true), routing = Some(RoutingValue))) ) }, + test("knnSearch") { + val knnSearchRequest = knnSearch(selectors = Index, query = KnnQuery) + val knnSearchRequestWithFilter = knnSearch(selectors = Index, query = KnnQuery).filter(query = Query) + val knnSearchRequestWithRouting = + knnSearch(selectors = Index, query = KnnQuery.similarity(3.14)).routing(RoutingValue) + val knnSearchRequestWithAllParams = + knnSearch(selectors = Index, query = KnnQuery).filter(query = Query).routing(RoutingValue) + + assert(knnSearchRequest)( + equalTo(KNN(knn = KnnQuery, selectors = Index.toSelector, filter = None, routing = None)) + ) && + assert(knnSearchRequestWithFilter)( + equalTo(KNN(knn = KnnQuery, selectors = Index.toSelector, filter = Some(Query), routing = None)) + ) && + assert(knnSearchRequestWithRouting)( + equalTo( + KNN( + knn = KnnQuery.similarity(3.14), + selectors = Index.toSelector, + filter = None, + routing = Some(RoutingValue) + ) + ) + ) && + assert(knnSearchRequestWithAllParams)( + equalTo( + KNN(knn = KnnQuery, selectors = Index.toSelector, filter = Some(Query), routing = Some(RoutingValue)) + ) + ) + }, test("refresh") { val refreshRequest = refresh(Index) val refreshWithMultiIndex = refresh(Indices) @@ -999,9 +1029,9 @@ object ElasticRequestSpec extends ZIOSpecDefault { val expected = """|{ "create" : { "_index" : "index", "routing" : "routing" } } - |{"stringField":"stringField1","subDocumentList":[],"dateField":"2020-10-10","intField":5,"doubleField":7.0,"booleanField":true,"geoPointField":{"lat":20.0,"lon":21.0}} + |{"stringField":"stringField1","subDocumentList":[],"dateField":"2020-10-10","intField":5,"doubleField":7.0,"booleanField":true,"geoPointField":{"lat":20.0,"lon":21.0},"vectorField":[]} |{ "index" : { "_index" : "index", "_id" : "documentid" } } - |{"stringField":"stringField2","subDocumentList":[],"dateField":"2022-10-10","intField":10,"doubleField":17.0,"booleanField":false,"geoPointField":{"lat":10.0,"lon":11.0}} + |{"stringField":"stringField2","subDocumentList":[],"dateField":"2022-10-10","intField":10,"doubleField":17.0,"booleanField":false,"geoPointField":{"lat":10.0,"lon":11.0},"vectorField":[]} |""".stripMargin assert(requestBody)(equalTo(expected)) @@ -1049,7 +1079,8 @@ object ElasticRequestSpec extends ZIOSpecDefault { | "geoPointField": { | "lat": 20.0, | "lon": 21.0 - | } + | }, + | "vectorField": [] |} |""".stripMargin @@ -1072,7 +1103,8 @@ object ElasticRequestSpec extends ZIOSpecDefault { | "geoPointField": { | "lat": 20.0, | "lon": 21.0 - | } + | }, + | "vectorField": [] |} |""".stripMargin @@ -1097,48 +1129,70 @@ object ElasticRequestSpec extends ZIOSpecDefault { assert(jsonRequest)(equalTo("")) && assert(jsonRequestWithDefinition)(equalTo(definition)) }, - test("upsert") { - val jsonRequest = upsert(index = Index, id = DocId, doc = Doc1) match { - case r: CreateOrUpdate => r.toJson + test("deleteByQuery") { + val jsonRequest = deleteByQuery(index = Index, query = Query) match { + case r: DeleteByQuery => r.toJson } val expected = """ |{ - | "stringField": "stringField1", - | "subDocumentList": [], - | "dateField": "2020-10-10", - | "intField": 5, - | "doubleField": 7.0, - | "booleanField": true, - | "geoPointField": { - | "lat": 20.0, - | "lon": 21.0 + | "query" : { + | "range" : { + | "intField" : { + | "gte" : 10 + | } + | } | } |} |""".stripMargin assert(jsonRequest)(equalTo(expected.toJson)) }, - test("deleteByQuery") { - val jsonRequest = deleteByQuery(index = Index, query = Query) match { - case r: DeleteByQuery => r.toJson + test("knnSearch") { + val jsonRequest = knnSearch(selectors = Index, query = KnnQuery) match { + case r: ElasticRequest.KNN => r.toJson } + val jsonRequestWithFilter = + knnSearch(selectors = Index, query = KnnQuery.similarity(3.14)).filter(query = Query) match { + case r: ElasticRequest.KNN => r.toJson + } val expected = """ |{ - | "query" : { - | "range" : { - | "intField" : { - | "gte" : 10 + | "knn": { + | "field": "stringField", + | "query_vector": [1.1, 3.3], + | "k": 10, + | "num_candidates": 21 + | } + |} + |""".stripMargin + + val expectedWithFilter = + """ + |{ + | "knn": { + | "field": "stringField", + | "query_vector": [1.1, 3.3], + | "k": 10, + | "num_candidates": 21, + | "similarity": 3.14 + | }, + | "filter": { + | "range": { + | "intField": { + | "gte": 10 | } | } | } |} |""".stripMargin - assert(jsonRequest)(equalTo(expected.toJson)) + assert(jsonRequest)(equalTo(expected.toJson)) && assert(jsonRequestWithFilter)( + equalTo(expectedWithFilter.toJson) + ) }, test("search") { val jsonRequest = search(Index, Query) match { @@ -1274,7 +1328,8 @@ object ElasticRequestSpec extends ZIOSpecDefault { | "doubleField", | "booleanField", | "geoPointField.lat", - | "geoPointField.lon" + | "geoPointField.lon", + | "vectorField" | ] | } |} @@ -1475,7 +1530,7 @@ object ElasticRequestSpec extends ZIOSpecDefault { assert(jsonRequestWithSortAndHighlights)(equalTo(expectedWithSortAndHighlights.toJson)) && assert(jsonRequestWithAllParams)(equalTo(expectedWithAllParams.toJson)) }, - test("update - doc") { + test("update") { val jsonRequest = update(index = Index, id = DocId, doc = Doc1) match { case r: Update => r.toJson } @@ -1496,7 +1551,8 @@ object ElasticRequestSpec extends ZIOSpecDefault { | "geoPointField": { | "lat": 20.0, | "lon": 21.0 - | } + | }, + | "vectorField": [] | } |} |""".stripMargin @@ -1514,7 +1570,8 @@ object ElasticRequestSpec extends ZIOSpecDefault { | "geoPointField": { | "lat": 20.0, | "lon": 21.0 - | } + | }, + | "vectorField": [] | }, | "upsert": { | "stringField": "stringField2", @@ -1526,7 +1583,8 @@ object ElasticRequestSpec extends ZIOSpecDefault { | "geoPointField": { | "lat": 10.0, | "lon": 11.0 - | } + | }, + | "vectorField": [] | } |} |""".stripMargin @@ -1534,12 +1592,12 @@ object ElasticRequestSpec extends ZIOSpecDefault { assert(jsonRequest)(equalTo(expected.toJson)) && assert(jsonRequestWithUpsert)(equalTo(expectedWithUpsert.toJson)) }, - test("update - script") { - val jsonRequest = updateByScript(index = Index, id = DocId, script = Script1) match { - case r: Update => r.toJson + test("updateByQuery") { + val jsonRequest = updateAllByQuery(index = Index, script = Script1) match { + case r: UpdateByQuery => r.toJson } - val jsonRequestWithUpsert = updateByScript(index = Index, id = DocId, script = Script1).orCreate(Doc2) match { - case r: Update => r.toJson + val jsonRequestWithQuery = updateByQuery(index = Index, query = Query, script = Script1) match { + case r: UpdateByQuery => r.toJson } val expected = @@ -1554,7 +1612,7 @@ object ElasticRequestSpec extends ZIOSpecDefault { |} |""".stripMargin - val expectedWithUpsert = + val expectedWithQuery = """ |{ | "script": { @@ -1563,30 +1621,25 @@ object ElasticRequestSpec extends ZIOSpecDefault { | "factor": 2 | } | }, - | "upsert": { - | "stringField": "stringField2", - | "subDocumentList": [], - | "dateField": "2022-10-10", - | "intField": 10, - | "doubleField": 17.0, - | "booleanField": false, - | "geoPointField": { - | "lat": 10.0, - | "lon": 11.0 + | "query" : { + | "range" : { + | "intField" : { + | "gte" : 10 + | } | } | } |} |""".stripMargin assert(jsonRequest)(equalTo(expected.toJson)) && - assert(jsonRequestWithUpsert)(equalTo(expectedWithUpsert.toJson)) + assert(jsonRequestWithQuery)(equalTo(expectedWithQuery.toJson)) }, - test("updateByQuery") { - val jsonRequest = updateAllByQuery(index = Index, script = Script1) match { - case r: UpdateByQuery => r.toJson + test("updateByScript") { + val jsonRequest = updateByScript(index = Index, id = DocId, script = Script1) match { + case r: Update => r.toJson } - val jsonRequestWithQuery = updateByQuery(index = Index, query = Query, script = Script1) match { - case r: UpdateByQuery => r.toJson + val jsonRequestWithUpsert = updateByScript(index = Index, id = DocId, script = Script1).orCreate(Doc2) match { + case r: Update => r.toJson } val expected = @@ -1601,7 +1654,7 @@ object ElasticRequestSpec extends ZIOSpecDefault { |} |""".stripMargin - val expectedWithQuery = + val expectedWithUpsert = """ |{ | "script": { @@ -1610,18 +1663,48 @@ object ElasticRequestSpec extends ZIOSpecDefault { | "factor": 2 | } | }, - | "query" : { - | "range" : { - | "intField" : { - | "gte" : 10 - | } - | } + | "upsert": { + | "stringField": "stringField2", + | "subDocumentList": [], + | "dateField": "2022-10-10", + | "intField": 10, + | "doubleField": 17.0, + | "booleanField": false, + | "geoPointField": { + | "lat": 10.0, + | "lon": 11.0 + | }, + | "vectorField": [] | } |} |""".stripMargin assert(jsonRequest)(equalTo(expected.toJson)) && - assert(jsonRequestWithQuery)(equalTo(expectedWithQuery.toJson)) + assert(jsonRequestWithUpsert)(equalTo(expectedWithUpsert.toJson)) + }, + test("upsert") { + val jsonRequest = upsert(index = Index, id = DocId, doc = Doc1) match { + case r: CreateOrUpdate => r.toJson + } + + val expected = + """ + |{ + | "stringField": "stringField1", + | "subDocumentList": [], + | "dateField": "2020-10-10", + | "intField": 5, + | "doubleField": 7.0, + | "booleanField": true, + | "geoPointField": { + | "lat": 20.0, + | "lon": 21.0 + | }, + | "vectorField": [] + |} + |""".stripMargin + + assert(jsonRequest)(equalTo(expected.toJson)) } ) ) @@ -1633,7 +1716,8 @@ object ElasticRequestSpec extends ZIOSpecDefault { intField = 5, doubleField = 7.0, booleanField = true, - geoPointField = GeoPoint(20.0, 21.0) + geoPointField = GeoPoint(20.0, 21.0), + vectorField = List() ) private val Doc2 = TestDocument( stringField = "stringField2", @@ -1642,13 +1726,15 @@ object ElasticRequestSpec extends ZIOSpecDefault { intField = 10, doubleField = 17.0, booleanField = false, - geoPointField = GeoPoint(10.0, 11.0) + geoPointField = GeoPoint(10.0, 11.0), + vectorField = List() ) private val DocId = DocumentId("documentid") private val Index = IndexName("index") private val MaxAggregation = ElasticAggregation.maxAggregation(name = "aggregation", field = TestDocument.intField) private val Indices = MultiIndex.names(Index, IndexName("index2")) private val Query = ElasticQuery.range(TestDocument.intField).gte(10) + private val KnnQuery = ElasticQuery.kNN(TestDocument.stringField, 10, 21, Chunk(1.1, 3.3)) private val RoutingValue = Routing("routing") private val Script1 = Script("doc['intField'].value * params['factor']").params("factor" -> 2) private val TermsAggregation = termsAggregation(name = "aggregation", field = "intField") diff --git a/modules/library/src/test/scala/zio/elasticsearch/HttpElasticExecutorSpec.scala b/modules/library/src/test/scala/zio/elasticsearch/HttpElasticExecutorSpec.scala index 13555fe80..9ee167632 100644 --- a/modules/library/src/test/scala/zio/elasticsearch/HttpElasticExecutorSpec.scala +++ b/modules/library/src/test/scala/zio/elasticsearch/HttpElasticExecutorSpec.scala @@ -18,7 +18,7 @@ package zio.elasticsearch import zio.Chunk import zio.elasticsearch.ElasticAggregation.termsAggregation -import zio.elasticsearch.ElasticQuery.{matchAll, term} +import zio.elasticsearch.ElasticQuery.{kNN, matchAll, term} import zio.elasticsearch.domain.TestDocument import zio.elasticsearch.executor.Executor import zio.elasticsearch.executor.response.{BulkResponse, CreateBulkResponse, Shards} @@ -176,6 +176,16 @@ object HttpElasticExecutorSpec extends SttpBackendStubSpec { assertZIO(executorGetById)(isSome(equalTo(doc))) }, + test("knnSearch") { + val executorSearch = + Executor + .execute( + ElasticRequest + .knnSearch(selectors = index, query = kNN(TestDocument.vectorField, 2, 5, Chunk(-5.0, 9.0, -12.0))) + ) + .documentAs[TestDocument] + assertZIO(executorSearch)(equalTo(Chunk(doc))) + }, test("refresh") { val executorRefresh = Executor.execute(ElasticRequest.refresh(selectors = index)) assertZIO(executorRefresh)(equalTo(true)) diff --git a/modules/library/src/test/scala/zio/elasticsearch/SttpBackendStubSpec.scala b/modules/library/src/test/scala/zio/elasticsearch/SttpBackendStubSpec.scala index 851131176..d055c6f1a 100644 --- a/modules/library/src/test/scala/zio/elasticsearch/SttpBackendStubSpec.scala +++ b/modules/library/src/test/scala/zio/elasticsearch/SttpBackendStubSpec.scala @@ -43,7 +43,8 @@ trait SttpBackendStubSpec extends ZIOSpecDefault { intField = 10, doubleField = 10.0, booleanField = true, - geoPointField = GeoPoint(1.0, 1.0) + geoPointField = GeoPoint(1.0, 1.0), + vectorField = List(1, 5, -20) ) val secondDoc: TestDocument = @@ -54,7 +55,8 @@ trait SttpBackendStubSpec extends ZIOSpecDefault { intField = 12, doubleField = 12.0, booleanField = true, - geoPointField = GeoPoint(1.0, 1.0) + geoPointField = GeoPoint(1.0, 1.0), + vectorField = List() ) private val url = "http://localhost:9200" @@ -209,7 +211,71 @@ trait SttpBackendStubSpec extends ZIOSpecDefault { | "geoPointField": { | "lat": 1.0, | "lon": 1.0 - | } + | }, + | "vectorField": [ + | 1, + | 5, + | -20 + | ] + | } + |}""".stripMargin, + StatusCode.Ok + ) + ) + + private val knnSearchStub: StubMapping = StubMapping( + request = r => r.method == Method.POST && r.uri.toString == s"$url/repositories/_knn_search", + response = Response( + """ + |{ + | "took": 5, + | "timed_out": false, + | "_shards": { + | "total": 8, + | "successful": 8, + | "failed": 0 + | }, + | "hits": { + | "total": { + | "value": 2, + | "relation": "eq" + | }, + | "max_score": 0.008547009, + | "hits": [ + | { + | "_index": "repositories", + | "_type": "type", + | "_id": "111", + | "_score": 0.008547009, + | "_source": { + | "stringField": "StringField", + | "subDocumentList": [ + | { + | "stringField": "StringField", + | "nestedField": { + | "stringField": "StringField", + | "longField": 1 + | }, + | "intField": 132, + | "intFieldList": [] + | } + | ], + | "dateField": "2020-10-11", + | "intField": 10, + | "doubleField": 10.0, + | "booleanField": true, + | "geoPointField": { + | "lat": 1.0, + | "lon": 1.0 + | }, + | "vectorField": [ + | 1, + | 5, + | -20 + | ] + | } + | } + | ] | } |}""".stripMargin, StatusCode.Ok @@ -265,7 +331,12 @@ trait SttpBackendStubSpec extends ZIOSpecDefault { | "geoPointField": { | "lat": 1.0, | "lon": 1.0 - | } + | }, + | "vectorField": [ + | 1, + | 5, + | -20 + | ] | } | } | ] @@ -319,7 +390,12 @@ trait SttpBackendStubSpec extends ZIOSpecDefault { | "geoPointField": { | "lat": 1.0, | "lon": 1.0 - | } + | }, + | "vectorField": [ + | 1, + | 5, + | -20 + | ] | } | } | ] @@ -341,13 +417,13 @@ trait SttpBackendStubSpec extends ZIOSpecDefault { ) ) - private val UpdateRequestStub: StubMapping = StubMapping( + private val updateRequestStub: StubMapping = StubMapping( request = r => r.method == Method.POST && r.uri.toString == s"$url/repositories/_update/V4x8q4UB3agN0z75fv5r?refresh=true&routing=routing", response = Response("Updated", StatusCode.Ok) ) - private val UpdateByQueryRequestStub: StubMapping = StubMapping( + private val updateByQueryRequestStub: StubMapping = StubMapping( request = r => r.method == Method.POST && r.uri.toString == s"$url/repositories/_update_by_query?conflicts=proceed&refresh=true&routing=routing", response = Response( @@ -377,11 +453,12 @@ trait SttpBackendStubSpec extends ZIOSpecDefault { deleteIndexRequestStub, existsRequestStub, getByIdRequestStub, + knnSearchStub, refreshRequestStub, searchRequestStub, searchWithAggregationRequestStub, - UpdateRequestStub, - UpdateByQueryRequestStub + updateRequestStub, + updateByQueryRequestStub ) private val sttpBackendStubLayer: TaskLayer[SttpBackendStub[Task, Any]] = ZLayer.succeed( diff --git a/modules/library/src/test/scala/zio/elasticsearch/domain/TestDocument.scala b/modules/library/src/test/scala/zio/elasticsearch/domain/TestDocument.scala index b7f6fe8e7..c3578f48d 100644 --- a/modules/library/src/test/scala/zio/elasticsearch/domain/TestDocument.scala +++ b/modules/library/src/test/scala/zio/elasticsearch/domain/TestDocument.scala @@ -13,14 +13,16 @@ final case class TestDocument( intField: Int, doubleField: Double, booleanField: Boolean, - geoPointField: GeoPoint + geoPointField: GeoPoint, + vectorField: List[Int] ) object TestDocument { - implicit val schema - : Schema.CaseClass7[String, List[TestSubDocument], LocalDate, Int, Double, Boolean, GeoPoint, TestDocument] = + implicit val schema: Schema.CaseClass8[String, List[TestSubDocument], LocalDate, Int, Double, Boolean, GeoPoint, List[ + Int + ], TestDocument] = DeriveSchema.gen[TestDocument] - val (stringField, subDocumentList, dateField, intField, doubleField, booleanField, geoPointField) = + val (stringField, subDocumentList, dateField, intField, doubleField, booleanField, geoPointField, vectorField) = schema.makeAccessors(FieldAccessorBuilder) }