Skip to content

Commit

Permalink
Cleanup + Lint
Browse files Browse the repository at this point in the history
  • Loading branch information
ejscribner committed Sep 12, 2024
1 parent 0aef582 commit fa2ccfc
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 119 deletions.
2 changes: 1 addition & 1 deletion api/configs/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,6 @@ class MiddlewareConfig(
TiDBVectorConfig,
WeaviateConfig,
ElasticsearchConfig,
CouchbaseConfig
CouchbaseConfig,
):
pass
12 changes: 5 additions & 7 deletions api/configs/middleware/vdb/couchbase_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,26 @@ class CouchbaseConfig(BaseModel):
"""

COUCHBASE_CONNECTION_STRING: Optional[str] = Field(
description='COUCHBASE connection string',
description="COUCHBASE connection string",
default=None,
)

COUCHBASE_USER: Optional[str] = Field(
description='COUCHBASE user',
description="COUCHBASE user",
default=None,
)

COUCHBASE_PASSWORD: Optional[str] = Field(
description='COUCHBASE password',
description="COUCHBASE password",
default=None,
)

COUCHBASE_BUCKET_NAME: Optional[str] = Field(
description='COUCHBASE bucket name',
description="COUCHBASE bucket name",
default=None,

)

COUCHBASE_SCOPE_NAME: Optional[str] = Field(
description='COUCHBASE scope name',
description="COUCHBASE scope name",
default=None,

)
158 changes: 70 additions & 88 deletions api/core/rag/datasource/vdb/couchbase/couchbase_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,23 @@ class CouchbaseConfig(BaseModel):
bucket_name: str
scope_name: str

@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
    """Reject a configuration that is missing any required Couchbase setting.

    Registered with ``mode="before"``, so it receives the raw input mapping.

    Args:
        values: Raw field values supplied to the model.

    Returns:
        ``values`` unchanged when every required key is present and truthy.

    Raises:
        ValueError: Naming the first setting (in declaration order) that is
            missing or empty.
    """
    # (field name, setting name reported to the operator), checked in order.
    required_settings = (
        ("connection_string", "COUCHBASE_CONNECTION_STRING"),
        ("user", "COUCHBASE_USER"),
        ("password", "COUCHBASE_PASSWORD"),
        # The bucket check used to report COUCHBASE_PASSWORD by mistake.
        ("bucket_name", "COUCHBASE_BUCKET_NAME"),
        ("scope_name", "COUCHBASE_SCOPE_NAME"),
    )
    for field_name, setting_name in required_settings:
        if not values.get(field_name):
            raise ValueError(f"config {setting_name} is required")
    return values

class CouchbaseVector(BaseVector):


class CouchbaseVector(BaseVector):
def __init__(self, collection_name: str, config: CouchbaseConfig):
super().__init__(collection_name)
self._client_config = config
Expand All @@ -68,14 +69,14 @@ def __init__(self, collection_name: str, config: CouchbaseConfig):
self._cluster.wait_until_ready(timedelta(seconds=5))

def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
    """Set up the collection's index and store the given documents.

    Delegates to ``_create_collection`` with a freshly generated index uuid
    and a vector length taken from the first embedding, then hands the
    documents and embeddings to ``add_texts``.
    """
    # ``.hex`` is the dash-free 32-character form of the UUID.
    search_index_uuid = uuid.uuid4().hex
    dimension = len(embeddings[0])
    self._create_collection(uuid=search_index_uuid, vector_length=dimension)
    self.add_texts(texts, embeddings)

def _create_collection(self, vector_length: int, uuid: str):
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
if self._collection_exists(self._collection_name):
Expand Down Expand Up @@ -165,10 +166,14 @@ def _create_collection(self, vector_length: int, uuid: str):
"sourceParams": { }
}
""")
index_definition['name'] = self._collection_name + '_search'
index_definition['uuid'] = uuid
index_definition['params']['mapping']['types']['collection_name']['properties']['embedding']['fields'][0]['dims'] = vector_length
index_definition['params']['mapping']['types'][self._scope_name + '.' + self._collection_name] = index_definition['params']['mapping']['types'].pop('collection_name')
index_definition["name"] = self._collection_name + "_search"
index_definition["uuid"] = uuid
index_definition["params"]["mapping"]["types"]["collection_name"]["properties"]["embedding"]["fields"][0][
"dims"
] = vector_length
index_definition["params"]["mapping"]["types"][self._scope_name + "." + self._collection_name] = (
index_definition["params"]["mapping"]["types"].pop("collection_name")
)
time.sleep(2)
index_manager.upsert_index(
SearchIndex(
Expand Down Expand Up @@ -206,32 +211,27 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], **
doc_ids = []

documents_to_insert = [
{
'text': text,
'embedding': vector,
'metadata': metadata
}
for id, text, vector, metadata in zip(
uuids, texts, embeddings, metadatas
)
{"text": text, "embedding": vector, "metadata": metadata}
for id, text, vector, metadata in zip(uuids, texts, embeddings, metadatas)
]
for doc,id in zip(documents_to_insert,uuids):
result = self._scope.collection(self._collection_name).upsert(id,doc)
for doc, id in zip(documents_to_insert, uuids):
result = self._scope.collection(self._collection_name).upsert(id, doc)




doc_ids.extend(uuids)

return doc_ids

def text_exists(self, id: str) -> bool:
    """Return whether a document keyed by ``id`` exists in this collection.

    Runs a ``COUNT`` query filtered on ``META().id`` and inspects the first
    row of the result; an empty result is treated as "not found".
    """
    # The keyspace is interpolated into the statement; the document key is
    # bound as the named parameter ``$doc_id``.
    statement = f"""
        SELECT COUNT(1) AS count FROM
        `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
        WHERE META().id = $doc_id
        """
    rows = self._cluster.query(statement, named_parameters={"doc_id": id}).execute()
    first_row = next(iter(rows), None)
    if first_row is None:
        return False  # no rows came back at all
    return first_row["count"] > 0

def delete_by_ids(self, ids: list[str]) -> None:
Expand All @@ -240,72 +240,61 @@ def delete_by_ids(self, ids: list[str]) -> None:
WHERE META().id IN $doc_ids;
"""
try:
result = self._cluster.query(query, named_parameters={'doc_ids': ids})
# force evaluation of the query to ensure deletion occurs
list(result)
self._cluster.query(query, named_parameters={"doc_ids": ids}).execute()
except Exception as e:
logger.error(e)

def delete_by_document_id(self, document_id: str):
    """Delete the document whose key equals ``document_id`` from this collection."""
    # The keyspace is interpolated into the statement; the key itself is
    # bound as the named parameter ``$doc_id``.
    statement = f"""
        DELETE FROM
        `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
        WHERE META().id = $doc_id;
        """
    pending = self._cluster.query(statement, named_parameters={"doc_id": document_id})
    # ``execute()`` is called explicitly so the DELETE is run rather than
    # left as an unconsumed result.
    pending.execute()

# def get_ids_by_metadata_field(self, key: str, value: str):
# query = f"""
# SELECT id FROM `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
# SELECT id FROM
# `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
# WHERE `metadata.{key}` = $value;
# """
# result = self._cluster.query(query, named_parameters={'value':value})
# return [row['id'] for row in result.rows()]


def delete_by_metadata_field(self, key: str, value: str) -> None:
    """Delete every document whose ``metadata.<key>`` equals ``value``.

    NOTE(review): ``key`` is interpolated into the statement text (only
    ``value`` is bound as a parameter), so callers should pass trusted
    field names only.
    """
    statement = f"""
        DELETE FROM `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
        WHERE metadata.{key} = $value;
        """
    pending = self._cluster.query(statement, named_parameters={"value": value})
    # ``execute()`` is called explicitly so the DELETE is run rather than
    # left as an unconsumed result.
    pending.execute()

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
score_threshold = kwargs.get("score_threshold") or 0.0

search_req = search.SearchRequest.create(
VectorSearch.from_vector_query(
VectorQuery(
'embedding',
"embedding",
query_vector,
top_k,

)
)
)
)
try:
search_iter = self._scope.search(
self._collection_name + '_search',
search_req,
SearchOptions(limit=top_k, collections=[self._collection_name],fields=['*']),
)
self._collection_name + "_search",
search_req,
SearchOptions(limit=top_k, collections=[self._collection_name], fields=["*"]),
)

docs = []
# Parse the results
for row in search_iter.rows():
text = row.fields.pop('text')
text = row.fields.pop("text")
metadata = self._format_metadata(row.fields)
score = row.score
metadata['score'] = score
metadata["score"] = score
doc = Document(page_content=text, metadata=metadata)
if score >= score_threshold:
docs.append(doc)
Expand All @@ -314,41 +303,36 @@ def search_by_vector(

return docs

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
    """Run a full-text search over the ``text`` field of the collection's search index.

    Args:
        query: Text appended to ``text:`` to form a query-string query.
        **kwargs: ``top_k`` limits the number of hits requested (default 2).

    Returns:
        One ``Document`` per hit; the hit's score is stored under
        ``metadata["score"]``.

    Raises:
        ValueError: If building or running the search fails. The original
            exception is chained as ``__cause__`` so the traceback is kept.
    """
    top_k = kwargs.get("top_k", 2)
    try:
        search_request = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
        search_iter = self._scope.search(
            self._collection_name + "_search", search_request, SearchOptions(limit=top_k, fields=["*"])
        )

        docs = []
        for row in search_iter.rows():
            # ``text`` is popped so it does not also end up in the metadata.
            text = row.fields.pop("text")
            metadata = self._format_metadata(row.fields)
            metadata["score"] = row.score
            docs.append(Document(page_content=text, metadata=metadata))
    except Exception as e:
        raise ValueError(f"Search failed with error: {e}") from e

    return docs

def delete(self):
    """Drop this instance's collection from the configured scope, if present.

    Only the scope named ``self._scope_name`` is inspected. The previous code
    matched the collection name in any scope but always issued the drop
    against the hard-coded ``"_default"`` scope, which targets the wrong
    scope whenever a different scope is configured.
    """
    manager = self._bucket.collections()
    for scope in manager.get_all_scopes():
        if scope.name != self._scope_name:
            continue
        for collection in scope.collections:
            if collection.name == self._collection_name:
                manager.drop_collection(scope.name, self._collection_name)
                return  # collection names are unique within a scope

def _format_metadata(self, row_fields: dict[str, Any]) -> dict[str, Any]:
"""Helper method to format the metadata from the Couchbase Search API.
Expand All @@ -362,16 +346,15 @@ def _format_metadata(self, row_fields: dict[str, Any]) -> dict[str, Any]:
for key, value in row_fields.items():
# Couchbase Search returns the metadata key with a prefix
# `metadata.` We remove it to get the original metadata key
if key.startswith('metadata'):
new_key = key.split('metadata' + ".")[-1]
if key.startswith("metadata"):
new_key = key.split("metadata" + ".")[-1]
metadata[new_key] = value
else:
metadata[key] = value

return metadata



class CouchbaseVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> CouchbaseVector:
if dataset.index_struct_dict:
Expand All @@ -380,17 +363,16 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.COUCHBASE, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.COUCHBASE, collection_name))

config = current_app.config
return CouchbaseVector(
collection_name=collection_name,
config=CouchbaseConfig(
connection_string=config.get('COUCHBASE_CONNECTION_STRING'),
user=config.get('COUCHBASE_USER'),
password=config.get('COUCHBASE_PASSWORD'),
bucket_name=config.get('COUCHBASE_BUCKET_NAME'),
scope_name=config.get('COUCHBASE_SCOPE_NAME'),
)
connection_string=config.get("COUCHBASE_CONNECTION_STRING"),
user=config.get("COUCHBASE_USER"),
password=config.get("COUCHBASE_PASSWORD"),
bucket_name=config.get("COUCHBASE_BUCKET_NAME"),
scope_name=config.get("COUCHBASE_SCOPE_NAME"),
),
)
18 changes: 0 additions & 18 deletions api/core/rag/datasource/vdb/vector_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@


class VectorType(str, Enum):
<<<<<<< HEAD
ANALYTICDB = "analyticdb"
CHROMA = "chroma"
MILVUS = "milvus"
Expand All @@ -18,20 +17,3 @@ class VectorType(str, Enum):
ORACLE = "oracle"
ELASTICSEARCH = "elasticsearch"
COUCHBASE = "couchbase"
=======
ANALYTICDB = 'analyticdb'
CHROMA = 'chroma'
MILVUS = 'milvus'
MYSCALE = 'myscale'
PGVECTOR = 'pgvector'
PGVECTO_RS = 'pgvecto-rs'
QDRANT = 'qdrant'
RELYT = 'relyt'
TIDB_VECTOR = 'tidb_vector'
WEAVIATE = 'weaviate'
OPENSEARCH = 'opensearch'
TENCENT = 'tencent'
ORACLE = 'oracle'
ELASTICSEARCH = 'elasticsearch'
COUCHBASE = 'couchbase'
>>>>>>> 8d7e8c48 (Cleanup)
11 changes: 6 additions & 5 deletions api/tests/integration_tests/vdb/couchbase/test_couchbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@ def __init__(self):
self.vector = CouchbaseVector(
collection_name=self.collection_name,
config=CouchbaseConfig(
connection_string = '127.0.0.1',
user = 'Administrator',
password = 'password',
bucket_name = 'Embeddings',
scope_name = '_default',
connection_string="127.0.0.1",
user="Administrator",
password="password",
bucket_name="Embeddings",
scope_name="_default",
),
)


def test_couchbase(setup_mock_redis):
CouchbaseTest().run_all_tests()

0 comments on commit fa2ccfc

Please sign in to comment.