Skip to content

Commit

Permalink
Merge pull request #107 from weaviate/fix-cache
Browse files Browse the repository at this point in the history
Fix unhashable type: 'VectorInputConfig' error with cache enabled
  • Loading branch information
antas-marcin authored Feb 1, 2025
2 parents 61c4560 + 073b181 commit a22f5d7
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 3 deletions.
8 changes: 5 additions & 3 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,11 @@ def log_info_about_onnx(onnx_runtime: bool):
cuda_support = True
cuda_core = os.getenv("CUDA_CORE")
if cuda_core is None or cuda_core == "":
if use_sentence_transformers_vectorizer and use_sentence_transformers_multi_process and torch.cuda.is_available():
if (
use_sentence_transformers_vectorizer
and use_sentence_transformers_multi_process
and torch.cuda.is_available()
):
available_workers = torch.cuda.device_count()
cuda_core = ",".join([f"cuda:{i}" for i in range(available_workers)])
else:
Expand All @@ -127,8 +131,6 @@ def log_info_about_onnx(onnx_runtime: bool):
else:
logger.info("Running on CPU")



# Batch text tokenization enabled by default
direct_tokenize = get_t2v_transformers_direct_tokenize()

Expand Down
8 changes: 8 additions & 0 deletions cicd/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ python3 smoke_auth_test.py

docker stop $container_id

# Smoke run with the cache switched on: the container is started with
# ENABLE_CACHE=1 and the cache validation test is run against it.
echo "Running tests with enabled cache"

# Start the image detached; host port 8000 is mapped to container port 8080.
container_id=$(docker run -d -it -e ENABLE_CACHE='1' -p "8000:8080" "$local_repo")

# Run the cache smoke test while the container is up.
python3 smoke_validate_cache_test.py

# Stop this container so the next run can bind host port 8000 again.
docker stop $container_id

echo "Running tests without authorization"

container_id=$(docker run -d -it -p "8000:8080" "$local_repo")
Expand Down
11 changes: 11 additions & 0 deletions smoke_validate_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,17 @@ def _test_vectorizing_sentences(self):
self._try_to_vectorize(self.url + "/vectors/", sentence)
self._try_to_vectorize(self.url + "/vectors", sentence)

def test_vectorize_payload_with_config(self):
    """Repeatedly vectorize a few short facts with a config argument.

    Each fact is sent to ``/vectors/`` with ``"query"`` and then to
    ``/vectors`` with ``"passage"`` as the third argument of
    ``_try_to_vectorize``; the full set of facts is replayed ten times.
    NOTE(review): the third argument presumably selects the task type of the
    request config -- confirm against ``_try_to_vectorize``.
    """
    sample_texts = [
        "Vector database for semantic search.",
        "Supports similarity-based queries.",
        "Integrates with ML for classification.",
    ]
    # Repeating the list ten times yields the same call order as ten passes
    # over the facts in a nested loop.
    for text in sample_texts * 10:
        self._try_to_vectorize(self.url + "/vectors/", text, "query")
        self._try_to_vectorize(self.url + "/vectors", text, "passage")

def test_vectorizing_cached_results(self):
start = time.time()
before = {}
Expand Down
19 changes: 19 additions & 0 deletions vectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,30 @@ class VectorInputConfig(BaseModel):
pooling_strategy: Optional[str] = None
task_type: Optional[str] = None

def __hash__(self):
    """Return a hash built from ``pooling_strategy`` and ``task_type``."""
    # Hash the two fields as one tuple so configs with equal field values
    # produce equal hashes.
    identity = (self.pooling_strategy, self.task_type)
    return hash(identity)

def __eq__(self, other):
    """Compare two configs by ``pooling_strategy`` and ``task_type``.

    Returns ``NotImplemented`` when ``other`` is not a ``VectorInputConfig``
    so Python can try the reflected comparison on the other operand; if that
    also declines, ``==`` still evaluates to ``False``.
    """
    if not isinstance(other, VectorInputConfig):
        return NotImplemented
    # Compare the same pair of fields that feeds the hash, so objects that
    # compare equal also hash equal.
    return (self.pooling_strategy, self.task_type) == (
        other.pooling_strategy,
        other.task_type,
    )


class VectorInput(BaseModel):
    """A text together with an optional ``VectorInputConfig``.

    Hashing and equality are both derived from ``(text, config)``, so two
    inputs that compare equal also hash equal.
    NOTE(review): presumably this is what lets instances serve as cache keys
    when caching is enabled -- confirm against the caching code. Because the
    hash depends on field values, an instance should not be mutated while it
    is used as a key.
    """

    text: str
    config: Optional[VectorInputConfig] = None

    def __hash__(self):
        """Return a hash built from ``text`` and ``config``."""
        return hash((self.text, self.config))

    def __eq__(self, other):
        """Compare two inputs by ``text`` and ``config``.

        Returns ``NotImplemented`` when ``other`` is not a ``VectorInput`` so
        Python can try the reflected comparison; if that also declines,
        ``==`` still evaluates to ``False``.
        """
        if not isinstance(other, VectorInput):
            return NotImplemented
        return self.text == other.text and self.config == other.config


class Vectorizer:
executor: ThreadPoolExecutor
Expand Down

0 comments on commit a22f5d7

Please sign in to comment.