Skip to content

Commit

Permalink
TODOs:
Browse files Browse the repository at this point in the history
1. Remove the parallel execution - it doesn't make sense for the GPU case, since that work can't be parallelized, and datasets.map() should be used for batching instead.
2. Make it run on HF datasets rather than pandas DataFrames.
  • Loading branch information
skrawcz committed May 20, 2024
1 parent b76d011 commit f840934
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions examples/LLM_Workflows/NER_Example/ner_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def medium_articles() -> pd.DataFrame:
df = load_dataset(
"fabiochiu/medium-articles", data_files="medium_articles.csv", split="train"
).to_pandas()
# TODO: change to HF dataset
return df


Expand All @@ -25,8 +26,8 @@ def sampled_articles(medium_articles: pd.DataFrame) -> pd.DataFrame:
return df


def device() -> int:
return torch.cuda.current_device() if torch.cuda.is_available() else None
def device() -> str:
return "cuda" if torch.cuda.is_available() else "cpu"


def model_id() -> str:
Expand All @@ -48,14 +49,15 @@ def model(model_id: str) -> object:


# load the tokenizer and model into a NER pipeline
def ner_pipeline(model: object, tokenizer: AutoTokenizer, device: int) -> base.Pipeline:
def ner_pipeline(model: object, tokenizer: AutoTokenizer, device: str) -> base.Pipeline:
print("Loading the ner_pipeline")
device_no = torch.cuda.current_device() if device == "cuda" else None
return pipeline(
"ner", model=model, tokenizer=tokenizer, aggregation_strategy="max", device=device
"ner", model=model, tokenizer=tokenizer, aggregation_strategy="max", device=device_no
)


def retriever(device: int) -> SentenceTransformer:
def retriever(device: str) -> SentenceTransformer:
"""A retriever model is used to embed passages (article title + first 1000 characters) and queries. It creates embeddings such that queries and passages with similar meanings are close in the vector space. We will use a sentence-transformer model as our retriever. The model can be loaded as follows:"""
print("Loading the retriever model")
return SentenceTransformer("flax-sentence-embeddings/all_datasets_v3_mpnet-base", device=device)
Expand All @@ -76,6 +78,7 @@ def batch_size() -> int:
def _extract_named_entities(text_batch, ner_pipeline):
# extract named entities using the NER pipeline
extracted_batch = ner_pipeline(text_batch)
# TODO: this should become extracted_batch = dataset.map(ner_pipeline)
entities = []
# loop through the results and only select the entity names
for text in extracted_batch:
Expand All @@ -98,6 +101,9 @@ def lancedb_data(
# extract batch
batch = sampled_articles.iloc[i:i_end].copy()
# generate embeddings for batch
# def _title_text_list(batch):
# return retriever.encode(batch["title_text"].tolist()).tolist()
# emb = dataset.map(_title_text_list)
emb = retriever.encode(batch["title_text"].tolist()).tolist()
# extract named entities from the batch
entities = _extract_named_entities(batch["title_text"].tolist(), ner_pipeline)
Expand Down Expand Up @@ -127,7 +133,7 @@ def load_into_lancedb(
) -> int:
try:
db.create_table(table_name, lancedb_data)
except ValueError:
except (OSError, ValueError):
tbl = db.open_table(table_name)
tbl.add(lancedb_data)
return len(lancedb_data)
Expand Down

0 comments on commit f840934

Please sign in to comment.