Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add lancedb ner example #912

Merged
merged 5 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions examples/LLM_Workflows/NER_Example/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Document processing with Named Entity Recognition (NER) for RAG
This example demonstrates how to use a Named Entity Recognition (NER) model to extract entities from text along
with embeddings to facilitate querying with more precision. Specifically we'll use the entities here to filter to
the documents that contain the entities of interest.

In general the concept we're showing here, is that if you extract extra metadata, like the entities text mentions,
this can be used when trying to find the most relevant text to pass to an LLM in a retrieval augmented generation (RAG)
context.

The pipeline we create can be seen in the image below.
![pipeine](ner_extraction_pipeline.png)

To run this in a notebook:

1. Install the requirements by running `pip install -r requirements.txt`.
2. Install `jupyter` by running `pip install jupyter`.
3. Run `jupyter notebook` in the current directory and open `notebook.ipynb`.

Alternatively open this notebook in Google Colab by clicking the button below:

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dagworks-inc/hamilton/blob/main/examples/LLM_Workflows/NER_Example/notebook.ipynb)

To run this example via the commandline :
1. Install the requirements by running `pip install -r requirements.txt`
2. Run the script `python run.py`. Some example commands:

- python run.py medium_docs load
- python run.py medium_docs query --query "Why does SpaceX want to build a city on Mars?"
- python run.py medium_docs query --query "How are autonomous vehicles changing the world?"

3. To see the full list of commands run `python run.py --help`.
105 changes: 105 additions & 0 deletions examples/LLM_Workflows/NER_Example/lancedb_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from typing import Union

import lancedb
import numpy as np
import pyarrow as pa
from datasets import Dataset
from datasets.formatting.formatting import LazyBatch
from sentence_transformers import SentenceTransformer


def db_client() -> lancedb.DBConnection:
"""the lancedb client"""
return lancedb.connect("./.lancedb")


def _write_to_lancedb(
data: Union[list[dict], pa.Table], db: lancedb.DBConnection, table_name: str
) -> int:
"""Helper function to write to lancedb.

This can handle the case the table exists or it doesn't.
"""
try:
skrawcz marked this conversation as resolved.
Show resolved Hide resolved
db.create_table(table_name, data)
except (OSError, ValueError):
tbl = db.open_table(table_name)
tbl.add(data)
return len(data)


def _batch_write(dataset_batch: LazyBatch, db, table_name, columns_of_interest) -> None:
"""Helper function to batch write to lancedb."""
# we pull out the pyarrow table and select what we want from it
if columns_of_interest is not None:
_write_to_lancedb(dataset_batch.pa_table.select(columns_of_interest), db, table_name)
else:
_write_to_lancedb(dataset_batch.pa_table, db, table_name)
return None


def loaded_lancedb_table(
final_dataset: Dataset,
db_client: lancedb.DBConnection,
table_name: str,
columns_of_interest: list[str],
write_batch_size: int = 100,
) -> lancedb.table.Table:
"""Loads the data into lancedb explicitly -- but we lose some visibility this way.

This function uses batching to write to lancedb.
"""
final_dataset.map(
_batch_write,
batched=True,
batch_size=write_batch_size,
fn_kwargs={
"db": db_client,
"table_name": table_name,
"columns_of_interest": columns_of_interest,
},
desc="writing to lancedb",
)
return db_client.open_table(table_name)


def lancedb_table(db_client: lancedb.DBConnection, table_name: str = "tw") -> lancedb.table.Table:
"""Table to query against"""
tbl = db_client.open_table(table_name)
return tbl


def lancedb_result(
query: str,
named_entities: list[str],
retriever: SentenceTransformer,
lancedb_table: lancedb.table.Table,
top_k: int = 10,
prefilter: bool = True,
) -> dict:
"""Result of querying lancedb.

:param query: the query
:param named_entities: the named entities found in the query
:param retriever: the model to create the embedding from the query
:param lancedb_table: the lancedb table to query against
:param top_k: number of top results
:param prefilter: whether to prefilter results before cosine distance
:return: dictionary result
"""
# create embeddings for the query
query_vector = np.array(retriever.encode(query).tolist())

# query the lancedb table
query_builder = lancedb_table.search(query_vector, vector_column_name="vector")
if named_entities:
# applying named entity filter if something was returned
where_clause = f"array_length(array_intersect({named_entities}, named_entities)) > 0"
query_builder = query_builder.where(where_clause, prefilter=prefilter)
result = (
query_builder.select(["title", "url", "named_entities"]) # what to return
.limit(top_k)
.to_list()
)
# could rerank results here
return {"Query": query, "Query Entities": named_entities, "Result": result}
176 changes: 176 additions & 0 deletions examples/LLM_Workflows/NER_Example/ner_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
from typing import Union

import torch
from datasets import Dataset, load_dataset # noqa: F401
from datasets.formatting.formatting import LazyBatch
from sentence_transformers import SentenceTransformer
from transformers import (
AutoModelForTokenClassification,
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizer,
pipeline,
)
from transformers.pipelines import base

from hamilton.function_modifiers import load_from, save_to, source, value

# Could explicitly load the dataset this way
# def medium_articles() -> Dataset:
# """Loads medium dataset into a hugging face dataset"""
# ds = load_dataset(
# "fabiochiu/medium-articles",
# data_files="medium_articles.csv",
# split="train"
# )
# return ds


@load_from.hf_dataset(
path=value("fabiochiu/medium-articles"),
data_files=value("medium_articles.csv"),
split=value("train"),
)
def medium_articles(dataset: Dataset) -> Dataset:
"""Loads medium dataset into a hugging face dataset"""
return dataset


def sampled_articles(
medium_articles: Dataset,
sample_size: int = 104,
random_state: int = 32,
max_text_length: int = 1000,
) -> Dataset:
"""Samples the articles and does some light transformations.
Transformations:
- selects the first 1000 characters of text. This is for performance here. But in real life you'd \
do something for your use case.
- Joins article title and the text to create one text string.
"""
# Filter out entries with NaN values in 'text' or 'title' fields
dataset = medium_articles.filter(
lambda example: example["text"] is not None and example["title"] is not None
)

# Shuffle and take the first 10000 samples
dataset = dataset.shuffle(seed=random_state).select(range(sample_size))

# Truncate the 'text' to the first 1000 characters
dataset = dataset.map(lambda example: {"text": example["text"][:max_text_length]})

# Concatenate the 'title' and truncated 'text'
dataset = dataset.map(lambda example: {"title_text": example["title"] + ". " + example["text"]})
return dataset


def device() -> str:
"""Whether this is a CUDA or CPU enabled device."""
return "cuda" if torch.cuda.is_available() else "cpu"


def NER_model_id() -> str:
"""Model ID to use
To extract named entities, we will use a NER model finetuned on a BERT-base model.
The model can be loaded from the HuggingFace model hub.
Use `overrides={"NER_model_id": VALUE}` to switch this without changing code.
"""
return "dslim/bert-base-NER"


def tokenizer(NER_model_id: str) -> PreTrainedTokenizer:
"""Loads the tokenizer for the NER model ID from huggingface"""
return AutoTokenizer.from_pretrained(NER_model_id)


def model(NER_model_id: str) -> PreTrainedModel:
"""Loads the NER model from huggingface"""
return AutoModelForTokenClassification.from_pretrained(NER_model_id)


def ner_pipeline(
model: PreTrainedModel, tokenizer: PreTrainedTokenizer, device: str
) -> base.Pipeline:
"""Loads the tokenizer and model into a NER pipeline. That is it combines them."""
device_no = torch.cuda.current_device() if device == "cuda" else None
return pipeline(
"ner", model=model, tokenizer=tokenizer, aggregation_strategy="max", device=device_no
)


def retriever(
device: str, retriever_model_id: str = "flax-sentence-embeddings/all_datasets_v3_mpnet-base"
) -> SentenceTransformer:
"""Our retriever model to create embeddings.

A retriever model is used to embed passages (article title + first 1000 characters)
and queries. It creates embeddings such that queries and passages with similar
meanings are close in the vector space. We will use a sentence-transformer model
as our retriever. The model can be loaded as follows:
"""
return SentenceTransformer(retriever_model_id, device=device)


def _extract_named_entities_text(
title_text_batch: Union[LazyBatch, list[str]], _ner_pipeline
) -> list[list[str]]:
"""Helper function to extract named entities given a batch of text."""
# extract named entities using the NER pipeline
extracted_batch = _ner_pipeline(title_text_batch)
# this should be extracted_batch = dataset.map(ner_pipeline)
entities = []
# loop through the results and only select the entity names
for text in extracted_batch:
ne = [entity["word"] for entity in text]
entities.append(ne)
_named_entities = [list(set(entity)) for entity in entities]
return _named_entities


def _batch_map(dataset: LazyBatch, _retriever, _ner_pipeline) -> dict:
"""Helper function to created the embedding vectors and extract named entities"""
title_text_list = dataset["title_text"]
emb = _retriever.encode(title_text_list)
_named_entities = _extract_named_entities_text(title_text_list, _ner_pipeline)
return {
"vector": emb,
"named_entities": _named_entities,
}


def columns_of_interest() -> list[str]:
"""The columns we expect to pull from the dataset to be saved to lancedb"""
return ["vector", "named_entities", "title", "url", "authors", "timestamp", "tags"]


@save_to.lancedb(
db_client=source("db_client"),
table_name=source("table_name"),
columns_to_write=source("columns_of_interest"),
output_name_="load_into_lancedb",
)
def final_dataset(
sampled_articles: Dataset,
retriever: SentenceTransformer,
ner_pipeline: base.Pipeline,
) -> Dataset:
"""The final dataset to be pushed to lancedb.

This adds two columns:

- vector -- the vector embedding
- named_entities -- the names of entities extracted from the text
"""
# goes over the data in batches so that the GPU can be properly utilized.
final_ds = sampled_articles.map(
_batch_map,
batched=True,
fn_kwargs={"_retriever": retriever, "_ner_pipeline": ner_pipeline},
desc="extracting entities",
)
return final_ds


def named_entities(query: str, ner_pipeline: base.Pipeline) -> list[str]:
"""The entities to extract from the query via the pipeline."""
return _extract_named_entities_text([query], ner_pipeline)[0]
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading