Skip to content

Commit

Permalink
Add Int4Flat and Int4Hnsw KnnTypes (#3121)
Browse files Browse the repository at this point in the history
* Add Int4Flat and Int4Hnsw KnnTypes

* Match on case object name values instead of strings
  • Loading branch information
Philippus authored Aug 17, 2024
1 parent 420748d commit ba3a246
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package com.sksamuel.elastic4s.requests.mappings

import com.sksamuel.elastic4s.fields.DenseVectorField.{Flat, Hnsw, Int8Flat, Int8Hnsw}
import com.sksamuel.elastic4s.fields.DenseVectorField.{Flat, Hnsw, Int4Flat, Int4Hnsw, Int8Flat, Int8Hnsw}
import com.sksamuel.elastic4s.ElasticApi
import com.sksamuel.elastic4s.fields.{Cosine, DenseVectorField, DenseVectorIndexOptions, DotProduct, L2Norm, MaxInnerProduct}
import com.sksamuel.elastic4s.handlers.fields.{DenseVectorFieldBuilderFn}
import com.sksamuel.elastic4s.handlers.fields.DenseVectorFieldBuilderFn
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

Expand Down Expand Up @@ -50,11 +50,15 @@ class DenseVectorFieldTest extends AnyFlatSpec with Matchers with ElasticApi {
val field = DenseVectorField(name = "myfield", dims = Some(3), index = Some(true), indexOptions = Some(denseVectorIndexOptions))
DenseVectorFieldBuilderFn.build(field).string shouldBe
"""{"type":"dense_vector","dims":3,"index":true,"index_options":{"type":"int8_hnsw","m":10,"ef_construction":100,"confidence_interval":1.0}}"""
DenseVectorFieldBuilderFn.build(field.indexOptions(denseVectorIndexOptions.copy(`type` = Int4Hnsw))).string shouldBe
"""{"type":"dense_vector","dims":3,"index":true,"index_options":{"type":"int4_hnsw","m":10,"ef_construction":100,"confidence_interval":1.0}}"""
DenseVectorFieldBuilderFn.build(field.indexOptions(denseVectorIndexOptions.copy(`type` = Hnsw))).string shouldBe
"""{"type":"dense_vector","dims":3,"index":true,"index_options":{"type":"hnsw","m":10,"ef_construction":100}}"""
DenseVectorFieldBuilderFn.build(field.indexOptions(denseVectorIndexOptions.copy(`type` = Flat))).string shouldBe
"""{"type":"dense_vector","dims":3,"index":true,"index_options":{"type":"flat"}}"""
DenseVectorFieldBuilderFn.build(field.indexOptions(denseVectorIndexOptions.copy(`type` = Int8Flat))).string shouldBe
"""{"type":"dense_vector","dims":3,"index":true,"index_options":{"type":"int8_flat","confidence_interval":1.0}}"""
DenseVectorFieldBuilderFn.build(field.indexOptions(denseVectorIndexOptions.copy(`type` = Int4Flat))).string shouldBe
"""{"type":"dense_vector","dims":3,"index":true,"index_options":{"type":"int4_flat","confidence_interval":1.0}}"""
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ object DenseVectorField {
}
case object Hnsw extends KnnType { val name = "hnsw" }
case object Int8Hnsw extends KnnType { val name = "int8_hnsw" }
case object Int4Hnsw extends KnnType { val name = "int4_hnsw" }
case object Flat extends KnnType { val name = "flat" }
case object Int8Flat extends KnnType { val name = "int8_flat" }
case object Int4Flat extends KnnType { val name = "int4_flat" }

@deprecated("Use the new apply method", "8.14.0")
def apply(name: String,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,39 +1,51 @@
package com.sksamuel.elastic4s.handlers.fields

import com.sksamuel.elastic4s.fields.DenseVectorField.{Hnsw, Int8Flat, Int8Hnsw}
import com.sksamuel.elastic4s.fields.DenseVectorField.{Hnsw, Int4Flat, Int4Hnsw, Int8Flat, Int8Hnsw}
import com.sksamuel.elastic4s.fields.{Cosine, DenseVectorField, DenseVectorIndexOptions, DotProduct, L2Norm, MaxInnerProduct, Similarity}
import com.sksamuel.elastic4s.json.{XContentBuilder, XContentFactory}

object DenseVectorFieldBuilderFn {
/** Resolves the textual form of a similarity metric to its [[Similarity]] case object.
  *
  * Recognised results are `L2Norm`, `DotProduct`, `Cosine` and `MaxInnerProduct`.
  * The match has no fallback case, so a string that matches none of the listed
  * patterns fails with a `scala.MatchError`.
  *
  * NOTE(review): this listing carries both the string-literal cases and the
  * `<CaseObject>.name` cases; presumably the literals are the pre-commit lines of the
  * diff ("Match on case object name values instead of strings") and only the `.name`
  * cases remain in the committed file -- confirm against the repository.
  *
  * @param similarity the similarity identifier to resolve
  * @return the matching [[Similarity]] case object
  */
private def similarityFromString(similarity: String): Similarity = similarity match {
// String-literal patterns.
case "l2_norm" => L2Norm
case "dot_product" => DotProduct
case "cosine" => Cosine
case "max_inner_product" => MaxInnerProduct
// Stable-identifier patterns: compare against each case object's `name` value.
case L2Norm.name => L2Norm
case DotProduct.name => DotProduct
case Cosine.name => Cosine
case MaxInnerProduct.name => MaxInnerProduct
}

/** Builds a [[DenseVectorIndexOptions]] from an untyped map of `index_options` values.
  *
  * The `"type"` entry is read with `values("type")`, cast to `String`, and selects the
  * branch:
  *  - `Hnsw`: reads optional `"m"` and `"ef_construction"` as `Int`.
  *  - `Int8Hnsw` / `Int4Hnsw`: as `Hnsw`, plus optional `"confidence_interval"` read as
  *    `Double` and narrowed to `Float`.
  *  - `Flat`: carries the type only.
  *  - `Int8Flat` / `Int4Flat`: passes `None` for the two HNSW parameters and reads
  *    optional `"confidence_interval"` as `Double` narrowed to `Float`.
  *
  * All conversions are `asInstanceOf` casts from `Any`, so an entry holding a different
  * runtime type fails at the cast. The match has no fallback case, so a type string
  * that matches none of the listed patterns fails with a `scala.MatchError`.
  *
  * NOTE(review): in this listing every string-literal `case "..." =>` line is directly
  * followed by a `case DenseVectorField.<X>.name =>` line; presumably the literal lines
  * are the pre-commit side of the diff and are absent from the committed file -- confirm.
  *
  * @param values raw option values keyed by field name; must contain `"type"`
  * @return the index options for the selected KNN type
  */
private def getIndexOptions(values: Map[String, Any]): DenseVectorIndexOptions =
values("type").asInstanceOf[String] match {
case "hnsw" => DenseVectorIndexOptions(
case DenseVectorField.Hnsw.name => DenseVectorIndexOptions(
DenseVectorField.Hnsw,
values.get("m").map(_.asInstanceOf[Int]),
values.get("ef_construction").map(_.asInstanceOf[Int])
)
case "int8_hnsw" => DenseVectorIndexOptions(
case DenseVectorField.Int8Hnsw.name => DenseVectorIndexOptions(
DenseVectorField.Int8Hnsw,
values.get("m").map(_.asInstanceOf[Int]),
values.get("ef_construction").map(_.asInstanceOf[Int]),
// The value is cast to Double first and then narrowed to Float.
values.get("confidence_interval").map(d => d.asInstanceOf[Double].toFloat)
)
case "flat" => DenseVectorIndexOptions(
case DenseVectorField.Int4Hnsw.name => DenseVectorIndexOptions(
DenseVectorField.Int4Hnsw,
values.get("m").map(_.asInstanceOf[Int]),
values.get("ef_construction").map(_.asInstanceOf[Int]),
values.get("confidence_interval").map(d => d.asInstanceOf[Double].toFloat)
)
case DenseVectorField.Flat.name => DenseVectorIndexOptions(
DenseVectorField.Flat
)
case "int8_flat" => DenseVectorIndexOptions(
case DenseVectorField.Int8Flat.name => DenseVectorIndexOptions(
DenseVectorField.Int8Flat,
// `m` and `ef_construction` are passed as None for the flat variants.
None,
None,
values.get("confidence_interval").map(d => d.asInstanceOf[Double].toFloat)
)
case DenseVectorField.Int4Flat.name => DenseVectorIndexOptions(
DenseVectorField.Int4Flat,
None,
None,
values.get("confidence_interval").map(d => d.asInstanceOf[Double].toFloat)
)
}

def toField(name: String, values: Map[String, Any]): DenseVectorField = DenseVectorField(
Expand All @@ -56,9 +68,9 @@ object DenseVectorFieldBuilderFn {
field.indexOptions.foreach { options =>
builder.startObject("index_options")
builder.field("type", options.`type`.name)
if (Seq(Hnsw, Int8Hnsw).contains(options.`type`)) options.m.foreach(builder.field("m", _))
if (Seq(Hnsw, Int8Hnsw).contains(options.`type`)) options.efConstruction.foreach(builder.field("ef_construction", _))
if (Seq(Int8Hnsw, Int8Flat).contains(options.`type`)) options.confidenceInterval.foreach(builder.field("confidence_interval", _))
if (Seq(Hnsw, Int8Hnsw, Int4Hnsw).contains(options.`type`)) options.m.foreach(builder.field("m", _))
if (Seq(Hnsw, Int8Hnsw, Int4Hnsw).contains(options.`type`)) options.efConstruction.foreach(builder.field("ef_construction", _))
if (Seq(Int8Hnsw, Int4Hnsw, Int8Flat, Int4Flat).contains(options.`type`)) options.confidenceInterval.foreach(builder.field("confidence_interval", _))
builder.endObject()
}
}
Expand Down

0 comments on commit ba3a246

Please sign in to comment.