# Patch summary (free text placed before the first "diff --git" header).
#
#   * .github/workflows/ci.yml
#       The es-version CI matrix entry 8.15.0 becomes 8.16.0.
#   * elasticsearch_dsl/field.py
#       DenseVector now derives from Field instead of Float, sets
#       _coerce = True and reads an "element_type" keyword (default "float").
#       For "float" and "byte" it forces multi=True and deserializes values
#       with float() / int(); any other element type is returned unchanged.
#   * tests/test_integration/_async/test_document.py
#     tests/test_integration/_sync/test_document.py
#       Add test_legacy_dense_vector (skipped when es_version >= (8, 16)) and
#       test_dense_vector (skipped when es_version < (8, 16)); the latter
#       round-trips a float vector, a byte vector and a bit vector given as
#       the string "12abf0".
#
# NOTE(review): line breaks and indentation were restored from a
# whitespace-collapsed copy of this patch. Context-line indentation is
# inferred -- confirm it against the target files before applying.

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 025387ec..24ce73c4 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -83,7 +83,7 @@ jobs:
             "3.12",
             "3.13",
           ]
-        es-version: [8.0.0, 8.15.0]
+        es-version: [8.0.0, 8.16.0]
 
     steps:
       - name: Remove irrelevant software to free up disk space
diff --git a/elasticsearch_dsl/field.py b/elasticsearch_dsl/field.py
index d992ecc3..43de7cbf 100644
--- a/elasticsearch_dsl/field.py
+++ b/elasticsearch_dsl/field.py
@@ -389,13 +389,23 @@ def _deserialize(self, data: Any) -> float:
         return float(data)
 
 
-class DenseVector(Float):
+class DenseVector(Field):
     name = "dense_vector"
+    _coerce = True
 
     def __init__(self, **kwargs: Any):
-        kwargs["multi"] = True
+        self._element_type = kwargs.get("element_type", "float")
+        if self._element_type in ["float", "byte"]:
+            kwargs["multi"] = True
         super().__init__(**kwargs)
 
+    def _deserialize(self, data: Any) -> Any:
+        if self._element_type == "float":
+            return float(data)
+        elif self._element_type == "byte":
+            return int(data)
+        return data
+
 
 class SparseVector(Field):
     name = "sparse_vector"
diff --git a/tests/test_integration/_async/test_document.py b/tests/test_integration/_async/test_document.py
index cbe9a405..6a6ca0fb 100644
--- a/tests/test_integration/_async/test_document.py
+++ b/tests/test_integration/_async/test_document.py
@@ -23,7 +23,7 @@
 
 from datetime import datetime
 from ipaddress import ip_address
-from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Union
+from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Tuple, Union
 
 import pytest
 from elasticsearch import AsyncElasticsearch, ConflictError, NotFoundError
@@ -37,6 +37,7 @@
     Binary,
     Boolean,
     Date,
+    DenseVector,
     Double,
     InnerDoc,
     Ip,
@@ -795,3 +796,57 @@ async def gen3() -> AsyncIterator[Union[Doc, Dict[str, Any]]]:
         "age": 45,
         "languages": ["es"],
     }
+
+
+@pytest.mark.asyncio
+async def test_legacy_dense_vector(
+    async_client: AsyncElasticsearch, es_version: Tuple[int, ...]
+) -> None:
+    if es_version >= (8, 16):
+        pytest.skip("this test is a legacy version for Elasticsearch 8.15 or older")
+
+    class Doc(AsyncDocument):
+        float_vector: List[float] = mapped_field(DenseVector(dims=3))
+
+        class Index:
+            name = "vectors"
+
+    await Doc._index.delete(ignore_unavailable=True)
+    await Doc.init()
+
+    doc = Doc(float_vector=[1.0, 1.2, 2.3])
+    await doc.save(refresh=True)
+
+    docs = await Doc.search().execute()
+    assert len(docs) == 1
+    assert docs[0].float_vector == doc.float_vector
+
+
+@pytest.mark.asyncio
+async def test_dense_vector(
+    async_client: AsyncElasticsearch, es_version: Tuple[int, ...]
+) -> None:
+    if es_version < (8, 16):
+        pytest.skip("this test requires Elasticsearch 8.16 or newer")
+
+    class Doc(AsyncDocument):
+        float_vector: List[float] = mapped_field(DenseVector())
+        byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
+        bit_vector: str = mapped_field(DenseVector(element_type="bit"))
+
+        class Index:
+            name = "vectors"
+
+    await Doc._index.delete(ignore_unavailable=True)
+    await Doc.init()
+
+    doc = Doc(
+        float_vector=[1.0, 1.2, 2.3], byte_vector=[12, 23, 34, 45], bit_vector="12abf0"
+    )
+    await doc.save(refresh=True)
+
+    docs = await Doc.search().execute()
+    assert len(docs) == 1
+    assert docs[0].float_vector == doc.float_vector
+    assert docs[0].byte_vector == doc.byte_vector
+    assert docs[0].bit_vector == doc.bit_vector
diff --git a/tests/test_integration/_sync/test_document.py b/tests/test_integration/_sync/test_document.py
index c36a9931..2f1d64b9 100644
--- a/tests/test_integration/_sync/test_document.py
+++ b/tests/test_integration/_sync/test_document.py
@@ -23,7 +23,7 @@
 
 from datetime import datetime
 from ipaddress import ip_address
-from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Tuple, Union
 
 import pytest
 from elasticsearch import ConflictError, Elasticsearch, NotFoundError
@@ -35,6 +35,7 @@
     Binary,
     Boolean,
     Date,
+    DenseVector,
     Document,
     Double,
     InnerDoc,
@@ -789,3 +790,55 @@ def gen3() -> Iterator[Union[Doc, Dict[str, Any]]]:
         "age": 45,
         "languages": ["es"],
     }
+
+
+@pytest.mark.sync
+def test_legacy_dense_vector(
+    client: Elasticsearch, es_version: Tuple[int, ...]
+) -> None:
+    if es_version >= (8, 16):
+        pytest.skip("this test is a legacy version for Elasticsearch 8.15 or older")
+
+    class Doc(Document):
+        float_vector: List[float] = mapped_field(DenseVector(dims=3))
+
+        class Index:
+            name = "vectors"
+
+    Doc._index.delete(ignore_unavailable=True)
+    Doc.init()
+
+    doc = Doc(float_vector=[1.0, 1.2, 2.3])
+    doc.save(refresh=True)
+
+    docs = Doc.search().execute()
+    assert len(docs) == 1
+    assert docs[0].float_vector == doc.float_vector
+
+
+@pytest.mark.sync
+def test_dense_vector(client: Elasticsearch, es_version: Tuple[int, ...]) -> None:
+    if es_version < (8, 16):
+        pytest.skip("this test requires Elasticsearch 8.16 or newer")
+
+    class Doc(Document):
+        float_vector: List[float] = mapped_field(DenseVector())
+        byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
+        bit_vector: str = mapped_field(DenseVector(element_type="bit"))
+
+        class Index:
+            name = "vectors"
+
+    Doc._index.delete(ignore_unavailable=True)
+    Doc.init()
+
+    doc = Doc(
+        float_vector=[1.0, 1.2, 2.3], byte_vector=[12, 23, 34, 45], bit_vector="12abf0"
+    )
+    doc.save(refresh=True)
+
+    docs = Doc.search().execute()
+    assert len(docs) == 1
+    assert docs[0].float_vector == doc.float_vector
+    assert docs[0].byte_vector == doc.byte_vector
+    assert docs[0].bit_vector == doc.bit_vector