Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a default for the dtype field of BaseVectorizer #261

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions redisvl/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,12 @@ def _make_field(storage_type, **field_inputs) -> BaseField:

@root_validator(pre=True)
@classmethod
def validate_and_create_fields(cls, values):
def validate_and_create_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""
Validate uniqueness of field names and create valid field instances.
"""
# Ensure index is a dictionary for validation
index = values.get("index")
index = values.get("index", {})
if not isinstance(index, IndexInfo):
index = IndexInfo(**index)

Expand Down
4 changes: 2 additions & 2 deletions redisvl/utils/vectorize/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from enum import Enum
from typing import Callable, List, Optional

from pydantic.v1 import BaseModel, validator
from pydantic.v1 import BaseModel, Field, validator

from redisvl.redis.utils import array_to_buffer
from redisvl.schema.fields import VectorDataType
Expand All @@ -21,7 +21,7 @@ class Vectorizers(Enum):
class BaseVectorizer(BaseModel, ABC):
model: str
dims: int
dtype: str
dtype: str = Field(default="float32")

@property
def type(self) -> str:
Expand Down
32 changes: 32 additions & 0 deletions tests/unit/test_base_vectorizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import List
from redisvl.utils.vectorize.base import BaseVectorizer



def test_base_vectorizer_defaults():
"""
Test that the base vectorizer defaults are set correctly, with
a default for dtype. Versions before 0.3.8 did not have this field.

A regression test for langchain-redis/#48
"""
class SimpleVectorizer(BaseVectorizer):
model: str = "simple"
dims: int = 10

def embed(self, text: str, **kwargs) -> List[float]:
return [0.0] * self.dims

async def aembed(self, text: str, **kwargs) -> List[float]:
return [0.0] * self.dims

async def aembed_many(self, texts: List[str], **kwargs) -> List[List[float]]:
return [[0.0] * self.dims] * len(texts)

def embed_many(self, texts: List[str], **kwargs) -> List[List[float]]:
return [[0.0] * self.dims] * len(texts)

vectorizer = SimpleVectorizer()
assert vectorizer.model == "simple"
assert vectorizer.dims == 10
assert vectorizer.dtype == "float32"
Loading