From ba3a246ec2492357bd2bd68cf240a5911ad21e0b Mon Sep 17 00:00:00 2001 From: Philippus Baalman Date: Sat, 17 Aug 2024 13:02:28 -0500 Subject: [PATCH] Add Int4Flat and Int4Hnsw KnnTypes (#3121) * Add Int4Flat and Int4Hnsw KnnTypes * Match on case object name values instead of strings --- .../mappings/DenseVectorFieldTest.scala | 8 +++-- .../elastic4s/fields/DenseVectorField.scala | 2 ++ .../fields/DenseVectorFieldBuilderFn.scala | 36 ++++++++++++------- 3 files changed, 32 insertions(+), 14 deletions(-) diff --git a/elastic4s-core/src/test/scala/com/sksamuel/elastic4s/requests/mappings/DenseVectorFieldTest.scala b/elastic4s-core/src/test/scala/com/sksamuel/elastic4s/requests/mappings/DenseVectorFieldTest.scala index 0ba3fc216..bae4ebf0a 100644 --- a/elastic4s-core/src/test/scala/com/sksamuel/elastic4s/requests/mappings/DenseVectorFieldTest.scala +++ b/elastic4s-core/src/test/scala/com/sksamuel/elastic4s/requests/mappings/DenseVectorFieldTest.scala @@ -1,9 +1,9 @@ package com.sksamuel.elastic4s.requests.mappings -import com.sksamuel.elastic4s.fields.DenseVectorField.{Flat, Hnsw, Int8Flat, Int8Hnsw} +import com.sksamuel.elastic4s.fields.DenseVectorField.{Flat, Hnsw, Int4Flat, Int4Hnsw, Int8Flat, Int8Hnsw} import com.sksamuel.elastic4s.ElasticApi import com.sksamuel.elastic4s.fields.{Cosine, DenseVectorField, DenseVectorIndexOptions, DotProduct, L2Norm, MaxInnerProduct} -import com.sksamuel.elastic4s.handlers.fields.{DenseVectorFieldBuilderFn} +import com.sksamuel.elastic4s.handlers.fields.DenseVectorFieldBuilderFn import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers @@ -50,11 +50,15 @@ class DenseVectorFieldTest extends AnyFlatSpec with Matchers with ElasticApi { val field = DenseVectorField(name = "myfield", dims = Some(3), index = Some(true), indexOptions = Some(denseVectorIndexOptions)) DenseVectorFieldBuilderFn.build(field).string shouldBe """{"type":"dense_vector","dims":3,"index":true,"index_options":{"type":"int8_hnsw","m":10,"ef_construction":100,"confidence_interval":1.0}}""" + DenseVectorFieldBuilderFn.build(field.indexOptions(denseVectorIndexOptions.copy(`type` = Int4Hnsw))).string shouldBe + """{"type":"dense_vector","dims":3,"index":true,"index_options":{"type":"int4_hnsw","m":10,"ef_construction":100,"confidence_interval":1.0}}""" DenseVectorFieldBuilderFn.build(field.indexOptions(denseVectorIndexOptions.copy(`type` = Hnsw))).string shouldBe """{"type":"dense_vector","dims":3,"index":true,"index_options":{"type":"hnsw","m":10,"ef_construction":100}}""" DenseVectorFieldBuilderFn.build(field.indexOptions(denseVectorIndexOptions.copy(`type` = Flat))).string shouldBe """{"type":"dense_vector","dims":3,"index":true,"index_options":{"type":"flat"}}""" DenseVectorFieldBuilderFn.build(field.indexOptions(denseVectorIndexOptions.copy(`type` = Int8Flat))).string shouldBe """{"type":"dense_vector","dims":3,"index":true,"index_options":{"type":"int8_flat","confidence_interval":1.0}}""" + DenseVectorFieldBuilderFn.build(field.indexOptions(denseVectorIndexOptions.copy(`type` = Int4Flat))).string shouldBe + """{"type":"dense_vector","dims":3,"index":true,"index_options":{"type":"int4_flat","confidence_interval":1.0}}""" } } diff --git a/elastic4s-domain/src/main/scala/com/sksamuel/elastic4s/fields/DenseVectorField.scala b/elastic4s-domain/src/main/scala/com/sksamuel/elastic4s/fields/DenseVectorField.scala index 7c91c1e3c..312cbf81b 100644 --- a/elastic4s-domain/src/main/scala/com/sksamuel/elastic4s/fields/DenseVectorField.scala +++ b/elastic4s-domain/src/main/scala/com/sksamuel/elastic4s/fields/DenseVectorField.scala @@ -8,8 +8,10 @@ object DenseVectorField { } case object Hnsw extends KnnType { val name = "hnsw" } case object Int8Hnsw extends KnnType { val name = "int8_hnsw" } + case object Int4Hnsw extends KnnType { val name = "int4_hnsw" } case object Flat extends KnnType { val name = "flat" } case object Int8Flat extends KnnType { val name = "int8_flat" } + case object Int4Flat extends KnnType { val name = "int4_flat" } @deprecated("Use the new apply method", "8.14.0") def apply(name: String, diff --git a/elastic4s-handlers/src/main/scala/com/sksamuel/elastic4s/handlers/fields/DenseVectorFieldBuilderFn.scala b/elastic4s-handlers/src/main/scala/com/sksamuel/elastic4s/handlers/fields/DenseVectorFieldBuilderFn.scala index 1fa27c03d..07b745e60 100644 --- a/elastic4s-handlers/src/main/scala/com/sksamuel/elastic4s/handlers/fields/DenseVectorFieldBuilderFn.scala +++ b/elastic4s-handlers/src/main/scala/com/sksamuel/elastic4s/handlers/fields/DenseVectorFieldBuilderFn.scala @@ -1,39 +1,51 @@ package com.sksamuel.elastic4s.handlers.fields -import com.sksamuel.elastic4s.fields.DenseVectorField.{Hnsw, Int8Flat, Int8Hnsw} +import com.sksamuel.elastic4s.fields.DenseVectorField.{Hnsw, Int4Flat, Int4Hnsw, Int8Flat, Int8Hnsw} import com.sksamuel.elastic4s.fields.{Cosine, DenseVectorField, DenseVectorIndexOptions, DotProduct, L2Norm, MaxInnerProduct, Similarity} import com.sksamuel.elastic4s.json.{XContentBuilder, XContentFactory} object DenseVectorFieldBuilderFn { private def similarityFromString(similarity: String): Similarity = similarity match { - case "l2_norm" => L2Norm - case "dot_product" => DotProduct - case "cosine" => Cosine - case "max_inner_product" => MaxInnerProduct + case L2Norm.name => L2Norm + case DotProduct.name => DotProduct + case Cosine.name => Cosine + case MaxInnerProduct.name => MaxInnerProduct } private def getIndexOptions(values: Map[String, Any]): DenseVectorIndexOptions = values("type").asInstanceOf[String] match { - case "hnsw" => DenseVectorIndexOptions( + case DenseVectorField.Hnsw.name => DenseVectorIndexOptions( DenseVectorField.Hnsw, values.get("m").map(_.asInstanceOf[Int]), values.get("ef_construction").map(_.asInstanceOf[Int]) ) - case "int8_hnsw" => DenseVectorIndexOptions( + case DenseVectorField.Int8Hnsw.name => DenseVectorIndexOptions( DenseVectorField.Int8Hnsw, values.get("m").map(_.asInstanceOf[Int]), values.get("ef_construction").map(_.asInstanceOf[Int]), values.get("confidence_interval").map(d => d.asInstanceOf[Double].toFloat) ) - case "flat" => DenseVectorIndexOptions( + case DenseVectorField.Int4Hnsw.name => DenseVectorIndexOptions( + DenseVectorField.Int4Hnsw, + values.get("m").map(_.asInstanceOf[Int]), + values.get("ef_construction").map(_.asInstanceOf[Int]), + values.get("confidence_interval").map(d => d.asInstanceOf[Double].toFloat) + ) + case DenseVectorField.Flat.name => DenseVectorIndexOptions( DenseVectorField.Flat ) - case "int8_flat" => DenseVectorIndexOptions( + case DenseVectorField.Int8Flat.name => DenseVectorIndexOptions( DenseVectorField.Int8Flat, None, None, values.get("confidence_interval").map(d => d.asInstanceOf[Double].toFloat) ) + case DenseVectorField.Int4Flat.name => DenseVectorIndexOptions( + DenseVectorField.Int4Flat, + None, + None, + values.get("confidence_interval").map(d => d.asInstanceOf[Double].toFloat) + ) } def toField(name: String, values: Map[String, Any]): DenseVectorField = DenseVectorField( @@ -56,9 +68,9 @@ object DenseVectorFieldBuilderFn { field.indexOptions.foreach { options => builder.startObject("index_options") builder.field("type", options.`type`.name) - if (Seq(Hnsw, Int8Hnsw).contains(options.`type`)) options.m.foreach(builder.field("m", _)) - if (Seq(Hnsw, Int8Hnsw).contains(options.`type`)) options.efConstruction.foreach(builder.field("ef_construction", _)) - if (Seq(Int8Hnsw, Int8Flat).contains(options.`type`)) options.confidenceInterval.foreach(builder.field("confidence_interval", _)) + if (Seq(Hnsw, Int8Hnsw, Int4Hnsw).contains(options.`type`)) options.m.foreach(builder.field("m", _)) + if (Seq(Hnsw, Int8Hnsw, Int4Hnsw).contains(options.`type`)) options.efConstruction.foreach(builder.field("ef_construction", _)) + if (Seq(Int8Hnsw, Int4Hnsw, Int8Flat, Int4Flat).contains(options.`type`)) options.confidenceInterval.foreach(builder.field("confidence_interval", _)) builder.endObject() } }