Skip to content

Commit

Permalink
Add filter option (#45)
Browse files Browse the repository at this point in the history
The default list is long now. This lets you do something like:

```
registry.filter(Type="RetrievalTask")
```
  • Loading branch information
hinthornw authored Nov 21, 2023
1 parent b3aee9d commit fd0203c
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from langchain.schema.retriever import BaseRetriever
from langchain.vectorstores.chroma import Chroma

from langchain_benchmarks.rag.utils._downloading import (
fetch_remote_file,
)
from langchain_benchmarks.rag.utils.indexing import (
get_hyde_retriever,
get_parent_document_retriever,
Expand All @@ -30,6 +33,10 @@ def load_docs_from_parquet(filename: Optional[str] = None) -> Iterable[Document]
"Please install pandas to use the langchain docs benchmarking task.\n"
"pip install pandas"
)
if filename is None:
filename = DOCS_FILE
if not os.path.exists(filename):
fetch_remote_file(REMOTE_DOCS_FILE, filename)
df = pd.read_parquet(filename)
docs_transformed = [Document(**row) for row in df.to_dict(orient="records")]
for doc in docs_transformed:
Expand All @@ -43,11 +50,13 @@ def load_docs_from_parquet(filename: Optional[str] = None) -> Iterable[Document]

def _chroma_retriever_factory(
embedding: Embeddings,
*,
docs: Optional[Iterable[Document]] = None,
search_kwargs: Optional[dict] = None,
transform_docs: Optional[Callable] = None,
transformation_name: Optional[str] = None,
) -> BaseRetriever:
docs = load_docs_from_parquet(DOCS_FILE)
docs = docs or load_docs_from_parquet()
embedding_name = embedding.__class__.__name__
vectorstore = Chroma(
collection_name=f"langchain-benchmarks-classic-{embedding_name}",
Expand All @@ -67,9 +76,11 @@ def _chroma_retriever_factory(

def _chroma_parent_document_retriever_factory(
embedding: Embeddings,
*,
docs: Optional[Iterable[Document]] = None,
search_kwargs: Optional[dict] = None,
) -> BaseRetriever:
docs = load_docs_from_parquet(DOCS_FILE)
docs = docs or load_docs_from_parquet()
embedding_name = embedding.__class__.__name__
vectorstore = Chroma(
collection_name=f"langchain-benchmarks-parent-doc-{embedding_name}",
Expand All @@ -87,9 +98,11 @@ def _chroma_parent_document_retriever_factory(

def _chroma_hyde_retriever_factory(
embedding: Embeddings,
*,
docs: Optional[Iterable[Document]] = None,
search_kwargs: Optional[dict] = None,
) -> BaseRetriever:
docs = load_docs_from_parquet(DOCS_FILE)
docs = docs or load_docs_from_parquet()
embedding_name = embedding.__class__.__name__
vectorstore = Chroma(
collection_name=f"langchain-benchmarks-hyde-{embedding_name}",
Expand Down
12 changes: 10 additions & 2 deletions langchain_benchmarks/rag/tasks/langchain_docs/task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from functools import partial
from typing import Iterable

from langchain.schema.document import Document

from langchain_benchmarks.rag.tasks.langchain_docs import architectures, indexing
from langchain_benchmarks.rag.tasks.langchain_docs.indexing.retriever_registry import (
Expand All @@ -11,12 +13,18 @@
"452ccafc-18e1-4314-885b-edd735f17b9d" # ID of public LangChain Docs dataset
)


def load_cached_docs() -> Iterable[Document]:
    """Return the documents read from the cached parquet file ``DOCS_FILE``."""
    cached_docs = load_docs_from_parquet(DOCS_FILE)
    return cached_docs


LANGCHAIN_DOCS_TASK = RetrievalTask(
name="LangChain Docs Q&A",
dataset_id=DATASET_ID,
retriever_factories=indexing.RETRIEVER_FACTORIES,
architecture_factories=architectures.ARCH_FACTORIES,
get_docs=partial(load_docs_from_parquet, DOCS_FILE),
get_docs=load_cached_docs,
description=(
"""\
Questions and answers based on a snapshot of the LangChain python docs.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def _chroma_parent_document_retriever_factory(
*,
docs: Optional[Iterable[Document]] = None,
search_kwargs: Optional[dict] = None,
transformation_name: Optional[str] = None,
) -> BaseRetriever:
docs = docs or load_docs()
embedding_name = embedding.__class__.__name__
Expand All @@ -140,6 +141,7 @@ def _chroma_parent_document_retriever_factory(
vectorstore,
collection_name="semi-structured-earnings",
search_kwargs=search_kwargs or _DEFAULT_SEARCH_KWARGS,
transformation_name=transformation_name,
)


Expand All @@ -148,6 +150,7 @@ def _chroma_hyde_retriever_factory(
*,
docs: Optional[Iterable[Document]] = None,
search_kwargs: Optional[dict] = None,
transformation_name: Optional[str] = None,
) -> BaseRetriever:
docs = docs or load_docs()
embedding_name = embedding.__class__.__name__
Expand All @@ -162,6 +165,7 @@ def _chroma_hyde_retriever_factory(
vectorstore,
collection_name="semi-structured-earnings",
search_kwargs=search_kwargs or _DEFAULT_SEARCH_KWARGS,
transformation_name=transformation_name,
)


Expand Down
27 changes: 27 additions & 0 deletions langchain_benchmarks/schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""Schema for the Langchain Benchmarks."""
from __future__ import annotations

import dataclasses
import inspect
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union

from langchain.prompts import ChatPromptTemplate
Expand Down Expand Up @@ -104,6 +107,7 @@ def _table(self) -> List[List[str]]:
return table + [
["Retriever Factories", ", ".join(self.retriever_factories.keys())],
["Architecture Factories", ", ".join(self.architecture_factories.keys())],
["get_docs", self.get_docs],
]


Expand Down Expand Up @@ -149,6 +153,29 @@ def _repr_html_(self) -> str:
]
return tabulate(table, headers=headers, tablefmt="html")

def filter(
    self,
    Type: Optional[str] = None,
    dataset_id: Optional[str] = None,
    name: Optional[str] = None,
    description: Optional[str] = None,
) -> Registry:
    """Return a new registry holding only the tasks that match every criterion.

    Each criterion that is left as ``None`` is ignored. Criteria are applied
    in order (type, dataset ID, name, description), each one narrowing the
    result of the previous one.

    Args:
        Type: Exact class name of the task, e.g. ``"RetrievalTask"``.
            (The capitalized name is part of the public keyword interface;
            inside this method it shadows ``typing.Type``.)
        dataset_id: Exact value to compare against ``task.dataset_id``.
        name: Exact value to compare against ``task.name``.
        description: Substring to look for in ``task.description``; the
            comparison is case-insensitive.

    Returns:
        A new ``Registry`` constructed from the matching tasks.
    """
    tasks = self.tasks
    if Type is not None:
        # Match on the class *name* so callers can filter with a plain string.
        tasks = [task for task in tasks if task.__class__.__name__ == Type]
    if dataset_id is not None:
        tasks = [task for task in tasks if task.dataset_id == dataset_id]
    if name is not None:
        tasks = [task for task in tasks if task.name == name]
    if description is not None:
        # Lower-case the needle once instead of once per task.
        needle = description.lower()
        tasks = [task for task in tasks if needle in task.description.lower()]
    return Registry(tasks=tasks)

def __getitem__(self, key: Union[int, str]) -> BaseTask:
"""Get an environment from the registry."""
if isinstance(key, slice):
Expand Down

0 comments on commit fd0203c

Please sign in to comment.