Skip to content

Commit

Permalink
update chroma client
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz committed Sep 29, 2024
1 parent ba7be56 commit 5f6b4e9
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 18 deletions.
19 changes: 10 additions & 9 deletions examples/refresh_chroma/refresh_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from raggy.loaders.base import Loader
from raggy.loaders.github import GitHubRepoLoader
from raggy.loaders.web import SitemapLoader
from raggy.vectorstores.chroma import Chroma
from raggy.vectorstores.chroma import Chroma, ChromaClientType


def html_parser(html: str) -> str:
Expand All @@ -21,6 +21,7 @@ def html_parser(html: str) -> str:

raggy.settings.html_parser = html_parser


prefect_loaders = [
SitemapLoader(
urls=[
Expand All @@ -31,7 +32,7 @@ def html_parser(html: str) -> str:
),
GitHubRepoLoader(
repo="PrefectHQ/prefect",
include_globs=["flows/"],
include_globs=["README.md"],
),
]

Expand All @@ -43,7 +44,7 @@ def html_parser(html: str) -> str:
cache_expiration=timedelta(days=1),
task_run_name="Run {loader.__class__.__name__}",
persist_result=True,
# refresh_cache=True,
refresh_cache=True,
)
async def run_loader(loader: Loader) -> list[Document]:
return await loader.load()
Expand All @@ -52,12 +53,14 @@ async def run_loader(loader: Loader) -> list[Document]:
@flow(name="Update Knowledge", log_prints=True)
async def refresh_chroma(
collection_name: str = "default",
chroma_client_type: Literal["base", "http"] = "base",
chroma_client_type: ChromaClientType = "base",
mode: Literal["upsert", "reset"] = "upsert",
):
"""Flow updating vectorstore with info from the Prefect community."""
documents = [
doc for future in run_loader.map(prefect_loaders) for doc in future.result()
doc
for future in run_loader.map(prefect_loaders) # type: ignore
for doc in future.result()
]

print(f"Loaded {len(documents)} documents from the Prefect community.")
Expand All @@ -73,14 +76,12 @@ async def refresh_chroma(
else:
raise ValueError(f"Unknown mode: {mode!r} (expected 'upsert' or 'reset')")

print(f"Added {len(docs)} documents to the {collection_name} collection.")
print(f"Added {len(docs)} documents to the {collection_name} collection.") # type: ignore


if __name__ == "__main__":
import asyncio

asyncio.run(
refresh_chroma(
collection_name="testing", chroma_client_type="base", mode="reset"
)
refresh_chroma(collection_name="docs", chroma_client_type="cloud", mode="reset")
)
5 changes: 3 additions & 2 deletions examples/refresh_tpuf/refresh_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ def html_parser(html: str) -> str:

prefect_loaders = [
SitemapLoader(
url_processor=lambda x: x.replace("docs.", "docs-2."),
urls=[
"https://docs-3.prefect.io/sitemap.xml",
"https://docs-2.prefect.io/sitemap.xml",
"https://prefect.io/sitemap.xml",
],
exclude=["api-ref", "www.prefect.io/events"],
Expand Down Expand Up @@ -81,4 +82,4 @@ async def refresh_tpuf_namespace(namespace: str = "testing", reset: bool = False
if __name__ == "__main__":
import asyncio

asyncio.run(refresh_tpuf_namespace(namespace="prefect-3", reset=True)) # type: ignore
asyncio.run(refresh_tpuf_namespace(namespace="prefect-2", reset=True)) # type: ignore
8 changes: 6 additions & 2 deletions src/raggy/loaders/web.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import re
from typing import Self
from typing import Callable, Self
from urllib.parse import urljoin

from bs4 import BeautifulSoup
Expand Down Expand Up @@ -148,12 +148,16 @@ class SitemapLoader(URLLoader):
exclude: list[str | re.Pattern] = Field(default_factory=list)
url_loader: URLLoader = Field(default_factory=HTMLLoader)

url_processor: Callable[[str], str] = lambda x: x # noqa: E731

async def _get_loader(self: Self) -> MultiLoader:
urls = await asyncio.gather(*[self.load_sitemap(url) for url in self.urls])
return MultiLoader(
loaders=[
type(self.url_loader)(urls=url_batch, headers=await self.get_headers()) # type: ignore
for url_batch in batched([u for url_list in urls for u in url_list], 10)
for url_batch in batched(
[self.url_processor(u) for url_list in urls for u in url_list], 10
)
]
)

Expand Down
22 changes: 21 additions & 1 deletion src/raggy/settings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Callable

from bs4 import BeautifulSoup
from pydantic import Field, field_validator
from pydantic import Field, SecretStr, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict


Expand Down Expand Up @@ -31,6 +31,24 @@ def default_html_parser(html: str) -> str:
return BeautifulSoup(html, "html.parser").get_text()


class ChromaSettings(BaseSettings):
    """Connection settings for the Chroma Cloud client.

    Values are loaded from the environment using the ``CHROMA_`` prefix
    (e.g. ``CHROMA_CLOUD_API_KEY``) or from a local ``.env`` file; unknown
    environment keys are ignored.
    """

    model_config = SettingsConfigDict(
        env_prefix="CHROMA_", env_file=".env", extra="ignore"
    )

    cloud_tenant: str = Field(
        default="default",
        description="The tenant to use for the Chroma Cloud client.",
    )
    cloud_database: str = Field(
        default="default",
        description="The database to use for the Chroma Cloud client.",
    )
    # Default to an empty secret so constructing these settings (and therefore
    # `Settings()`, which builds them via `default_factory=ChromaSettings`) does
    # not raise for users who never touch Chroma Cloud. A cloud client created
    # with an unset key will still fail at connection time, which is the right
    # place for that error to surface.
    cloud_api_key: SecretStr = Field(
        default=SecretStr(""),
        description="The API key to use for the Chroma Cloud client.",
    )


class Settings(BaseSettings):
"""The settings for Raggy.
Expand Down Expand Up @@ -73,6 +91,8 @@ class Settings(BaseSettings):
description="The OpenAI model to use for creating embeddings.",
)

chroma: ChromaSettings = Field(default_factory=ChromaSettings) # type: ignore

@field_validator("log_level", mode="after")
@classmethod
def set_log_level(cls, v):
Expand Down
18 changes: 14 additions & 4 deletions src/raggy/vectorstores/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from typing import Iterable, Literal

try:
from chromadb import Client, HttpClient
from chromadb import Client, CloudClient, HttpClient
from chromadb.api import ClientAPI
from chromadb.api.models.Collection import Collection
from chromadb.api.types import Include, QueryResult
except ImportError:
Expand All @@ -12,17 +13,26 @@
)

from raggy.documents import Document, get_distinct_documents
from raggy.settings import settings
from raggy.utilities.asyncutils import run_sync_in_worker_thread
from raggy.utilities.embeddings import create_openai_embeddings
from raggy.utilities.text import slice_tokens
from raggy.vectorstores.base import Vectorstore

ChromaClientType = Literal["base", "http", "cloud"]


def get_client(client_type: ChromaClientType) -> ClientAPI:
    """Build a Chroma client for the requested deployment flavor.

    Args:
        client_type: ``"base"`` for an in-process client, ``"http"`` for a
            client talking to a Chroma server, or ``"cloud"`` for Chroma
            Cloud (credentials come from ``settings.chroma``).

    Returns:
        A ready-to-use Chroma ``ClientAPI`` instance.

    Raises:
        ValueError: If ``client_type`` is not one of the recognized values.
    """
    if client_type == "cloud":
        chroma_settings = settings.chroma
        return CloudClient(
            tenant=chroma_settings.cloud_tenant,
            database=chroma_settings.cloud_database,
            api_key=chroma_settings.cloud_api_key.get_secret_value(),
        )

    # The remaining flavors take no arguments, so a simple dispatch table works.
    local_factories = {"base": Client, "http": HttpClient}
    try:
        factory = local_factories[client_type]
    except KeyError:
        raise ValueError(f"Unknown client type: {client_type}") from None
    return factory()

Expand All @@ -46,7 +56,7 @@ class Chroma(Vectorstore):
"""

client_type: Literal["base", "http"] = "base"
client_type: ChromaClientType = "base"
collection_name: str = "raggy"

@property
Expand Down Expand Up @@ -141,7 +151,7 @@ async def reset_collection(self):
await run_sync_in_worker_thread(
client.delete_collection, self.collection_name
)
except ValueError:
except Exception:
self.logger.warning_kv(
"Collection not found",
f"Creating a new collection {self.collection_name!r}",
Expand Down

0 comments on commit 5f6b4e9

Please sign in to comment.