-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* ✨ Added KnowledgeExtractor. Added KnowGL Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com> * ✅ Created tests for KnowledgeExtractor Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com> * 📝 Added documentation of Knowledge Extractor component Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com> * ♻️ Created mappings module from spans_to_ functions of Regen Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com> * 🐛 Fix bug in labels in KnowGL Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com> * 📝 Updated README Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com> * 🎨 Fix flake Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com> * 📝 Change test name Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com> * 🐛 Fix bug in relations renderer (#65) Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com> * Fix/tests (#66) * ✅🐛 Improve tests performance. Fix minor bugs related Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com> * ✅ Added download models Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com> * 🎨 Fixed flake Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com> * 🎨 Fixed flake Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com> * Added try/except on load models Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com> * Added try/except on load models Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com> * ✅ Added xfail to tars tests Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com> * ✅ Updated pydantic requirements Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com> * ✅ Added fewrel tests Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com> * ✅ Updated tests Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com> * Revert "✅ Updated tests" This reverts commit eae7c8c. Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com> --------- Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com> --------- Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com>
- Loading branch information
Showing
19 changed files
with
657 additions
and
85 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# KnowGL Knowledge Extractor | ||
|
||
The knowgl-large model is trained by combining Wikidata with an extended version of the training data in the REBEL dataset. Given a sentence, KnowGL generates triple(s) in the following format: | ||
``` | ||
[(subject mention # subject label # subject type) | relation label | (object mention # object label # object type)] | ||
``` | ||
If there are more than one triples generated, they are separated by $ in the output. The model achieves state-of-the-art results for relation extraction on the REBEL dataset. The generated labels (for the subject, relation, and object) and their types can be directly mapped to Wikidata IDs associated with them. | ||
|
||
This `KnowledgeExtractor` does not use any entity/relation pre-defined. | ||
|
||
- [Paper Rossiello et al. (AAAI 2023)](https://arxiv.org/pdf/2210.13952.pdf) | ||
- [Paper Mihindukulasooriya et al. (ISWC 2022)](https://arxiv.org/pdf/2207.05188.pdf) | ||
- [Original Model](https://huggingface.co/ibm/knowgl-large) | ||
|
||
::: zshot.knowledge_extractor.KnowGL |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Knowledge Extractor | ||
|
||
The **knowledge extractor** will perform at the same time the extraction and classification of named entities and the extraction of relations among them. | ||
|
||
Currently, the is only one Knowledge Extractor available: KnowGL | ||
|
||
|
||
::: zshot.KnowledgeExtractor |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from zshot.knowledge_extractor.knowledge_extractor import KnowledgeExtractor # noqa: F401 | ||
from zshot.knowledge_extractor.knowgl.knowledge_extractor_knowgl import KnowGL # noqa: F401 |
Empty file.
77 changes: 77 additions & 0 deletions
77
zshot/knowledge_extractor/knowgl/knowledge_extractor_knowgl.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
from typing import List, Tuple, Iterator, Optional, Union | ||
|
||
from spacy.tokens import Doc | ||
from tokenizers import Encoding | ||
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | ||
|
||
from zshot.knowledge_extractor.knowgl.utils import get_words_mappings, get_spans, get_triples | ||
from zshot.knowledge_extractor.knowledge_extractor import KnowledgeExtractor | ||
from zshot.utils.data_models import Span | ||
from zshot.utils.data_models.relation_span import RelationSpan | ||
|
||
|
||
class KnowGL(KnowledgeExtractor): | ||
def __init__(self, model_name="ibm/knowgl-large"): | ||
""" Instantiate the KnowGL Knowledge Extractor """ | ||
super().__init__() | ||
|
||
self.model_name = model_name | ||
self.model = None | ||
self.tokenizer = None | ||
|
||
def load_models(self): | ||
""" Load KnowGL model """ | ||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) | ||
self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name) | ||
self.model.to(self.device) | ||
|
||
def parse_result(self, result: str, doc: Doc, | ||
encodings: Encoding) -> List[Tuple[Span, RelationSpan, Span]]: | ||
""" Parse the text result into a list of triples | ||
:param result: Text generate by the KnowGL model | ||
:param doc: Spacy doc | ||
:param encodings: Encodings result of the tokenization | ||
:return: List of triples (subject, relation, object) | ||
""" | ||
words_mapping, char_mapping = get_words_mappings(encodings, doc.text) | ||
triples = [] | ||
for triple in result.split("$"): | ||
subject_, relation, object_ = triple.split("|") | ||
s_mention, s_label, s_type = subject_.strip("[()]").split("#") | ||
o_mention, o_label, o_type = object_.strip("[()]").split("#") | ||
s_type = s_label if s_label != "None" else s_type | ||
o_type = o_label if o_label != "None" else o_type | ||
subject_spans = get_spans(s_mention, s_type, self.tokenizer, encodings, | ||
words_mapping, char_mapping) | ||
object_spans = get_spans(o_mention, o_type, self.tokenizer, encodings, | ||
words_mapping, char_mapping) | ||
triples += get_triples(subject_spans, relation, object_spans) | ||
|
||
return triples | ||
|
||
def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) \ | ||
-> List[List[Tuple[Span, RelationSpan, Span]]]: | ||
""" Extract triples from docs | ||
:param docs: Spacy Docs to process | ||
:param batch_size: Batch size for processing | ||
:return: Triples (subject, relation, object) extracted for each document | ||
""" | ||
if not self.model: | ||
self.load_models() | ||
|
||
texts = [d.text for d in docs] | ||
input_data = self.tokenizer(texts, | ||
truncation=True, | ||
padding=True, | ||
return_tensors="pt") | ||
input_ids = input_data.input_ids.to(self.model.device) | ||
outputs = self.model.generate(inputs=input_ids) | ||
|
||
triples = [] | ||
for doc, output, encodings in zip(docs, outputs, input_data.encodings): | ||
result = self.tokenizer.decode(token_ids=output, skip_special_tokens=True) | ||
triples.append(self.parse_result(result, doc, encodings)) | ||
|
||
return triples |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
from itertools import groupby, product | ||
from typing import List, Any, Dict, Tuple | ||
|
||
from tokenizers import Tokenizer, Encoding | ||
|
||
from zshot.utils.data_models import Span, Relation | ||
from zshot.utils.data_models.relation_span import RelationSpan | ||
|
||
|
||
def ranges(lst: List[int]) -> List[List[int]]: | ||
""" Get groups made by consecutive numbers in the given list | ||
:param lst: List to get groups from | ||
:return: Groups of consecutive numbers | ||
""" | ||
pos = (j - i for i, j in enumerate(lst)) | ||
t = 0 | ||
groups = [] | ||
for i, els in groupby(pos): | ||
lst_ = len(list(els)) | ||
el = lst[t] | ||
t += lst_ | ||
groups.append(list(range(el, el + lst_))) | ||
|
||
return groups | ||
|
||
|
||
def find_sub_list(sl: List[Any], lst: List[Any]): | ||
""" Return init and end indexes of a sublist in a list | ||
:param sl: Sublist | ||
:param lst: List | ||
:return: List of tuples with the init and the end indexes | ||
""" | ||
results = [] | ||
sll = len(sl) | ||
for ind in (i for i, e in enumerate(lst) if e == sl[0]): | ||
if lst[ind:ind + sll] == sl: | ||
results.append((ind, ind + sll - 1)) | ||
|
||
return results | ||
|
||
|
||
def get_spans(mention: str, label: str, | ||
tokenizer: Tokenizer, encodings: Encoding, | ||
words_mapping: Dict[int, str], char_mapping: Dict[int, List[int]]) -> List[Span]: | ||
""" Get spans from a mention | ||
:param mention: Mention text to get Spans from | ||
:param label: Label to assign to the Spans | ||
:param tokenizer: Tokenizer used for tokenization | ||
:param encodings: Encodings result of the tokenization | ||
:param words_mapping: Mapping from words indexes to words | ||
:param char_mapping: Mapping from words indexes to char init/end indexes | ||
:return: List of Spans | ||
""" | ||
spans = [] | ||
|
||
# Find tokens in the list of encodings to create the spans | ||
tokens = tokenizer.encode(mention) | ||
results = find_sub_list(tokens[1:-1], encodings.ids) | ||
for result in results: | ||
init = encodings.token_to_chars(result[0])[0] | ||
end = encodings.token_to_chars(result[1])[-1] | ||
spans.append(Span(init, end, label)) | ||
|
||
# With some tokenizers, the result might be different depending on the surroundings of the mention | ||
# In this case, get the words indexes to get the span char limits | ||
if not spans: | ||
words = mention.lower().split() | ||
words_idxs = [k for k, v in words_mapping.items() if v.lower() in words] | ||
valid_groups = [group for group in ranges(words_idxs) if len(group) == len(words)] | ||
for group in valid_groups: | ||
init = char_mapping[group[0]][0] | ||
end = char_mapping[group[-1]][-1] | ||
spans.append(Span(init, end, label)) | ||
return spans | ||
|
||
|
||
def get_words_mappings(encodings: Encoding, text: str) -> Tuple[Dict[int, str], Dict[int, List[int]]]: | ||
""" Get words mappings from word index to word string and char span | ||
:param encodings: Encodings result of tokenization | ||
:param text: Text to get words mappings from | ||
:return: Mapping from words indexes to words and Mapping from words indexes to char init/end indexes | ||
""" | ||
words_mapping = {} | ||
char_mapping = {} | ||
for token_idx in range(len(encodings.ids)): | ||
try: | ||
init, end = encodings.token_to_chars(token_idx) | ||
word_idx: int = encodings.token_to_word(token_idx) | ||
if word_idx not in words_mapping: | ||
words_mapping[word_idx] = text[init:end].lower() | ||
char_mapping[word_idx] = [init, end] | ||
else: | ||
words_mapping[word_idx] += text[init:end].lower() | ||
char_mapping[word_idx][1] = end | ||
except TypeError: | ||
pass | ||
|
||
return words_mapping, char_mapping | ||
|
||
|
||
def get_triples(subject_spans: List[Span], relation: str, object_spans: List[Span]) \ | ||
-> List[Tuple[Span, RelationSpan, Span]]: | ||
""" Get all possible triples from the spans | ||
:param subject_spans: List of spans for the subject | ||
:param relation: Relation name | ||
:param object_spans: List of spans for the object | ||
:return: List of triples (subject, relation, object) | ||
""" | ||
triples = [] | ||
relation = Relation(name=relation, description="") | ||
# As one word might be repeated in the text, we have to relate all of them | ||
for comb in product(subject_spans, object_spans): | ||
triples.append((comb[0], | ||
RelationSpan(start=comb[0], end=comb[1], relation=relation), | ||
comb[1])) | ||
|
||
return triples |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import os | ||
import pickle as pkl | ||
from abc import ABC, abstractmethod | ||
from typing import List, Iterator, Optional, Union, Tuple | ||
|
||
import torch | ||
import zlib | ||
from spacy.tokens import Doc | ||
|
||
from zshot.utils.alignment_utils import filter_overlapping_spans, spacy_token_offsets | ||
from zshot.utils.data_models import Span | ||
from zshot.utils.data_models.relation_span import RelationSpan | ||
|
||
|
||
class KnowledgeExtractor(ABC): | ||
|
||
def __init__(self, device: Optional[Union[str, torch.device]] = None): | ||
""" Instantiate the Knowledge Extractor | ||
:param device: Device to be used for computation | ||
""" | ||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device | ||
|
||
def set_device(self, device: Union[str, torch.device]): | ||
""" | ||
Set the device to use | ||
:param device: | ||
:return: | ||
""" | ||
self.device = device | ||
|
||
def load_models(self): | ||
""" | ||
Load the model | ||
:return: | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) \ | ||
-> List[List[Tuple[Span, RelationSpan, Span]]]: | ||
""" | ||
Perform the knowledge extraction. | ||
:param docs: A list of spacy Document | ||
:param batch_size: The batch size | ||
:return: the predicted triples | ||
""" | ||
pass | ||
|
||
def parse_triples(self, preds: List[Tuple[Span, RelationSpan, Span]]) -> Tuple[List[Span], List[RelationSpan]]: | ||
""" Parse the triples into lists of entities and relations | ||
:param preds: Predicted triples | ||
:return: Tuple with list of entities and list of relations | ||
""" | ||
entities = [] | ||
relations = [] | ||
for triple in preds: | ||
entities.append(triple[0]) | ||
entities.append(triple[2]) | ||
relations.append(triple[1]) | ||
|
||
return list(set(entities)), list(set(relations)) | ||
|
||
def extract_knowledge(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None): | ||
""" | ||
Perform the relations extraction. Call the predict function and add the mentions to the Spacy Doc | ||
:param docs: A list of spacy Document | ||
:param batch_size: The batch size | ||
:return: | ||
""" | ||
predicted_triples = self.predict(docs, batch_size) | ||
for d, preds in zip(docs, predicted_triples): | ||
entities, relations = self.parse_triples(preds) | ||
d._.relations = relations | ||
d._.spans = entities | ||
d.ents = map(lambda p: p.to_spacy_span(d), filter_overlapping_spans(entities, list(d), | ||
tokens_offsets=spacy_token_offsets(d))) | ||
|
||
@staticmethod | ||
def version() -> str: | ||
return "v1" | ||
|
||
@staticmethod | ||
def _get_serialize_file(path): | ||
""" Get full filepath of the serialization file """ | ||
return os.path.join(path, "knowledge_extractor.pkl") | ||
|
||
@classmethod | ||
def from_disk(cls, path, exclude=()): | ||
""" Load component from disk """ | ||
serialize_file = cls._get_serialize_file(path) | ||
with open(serialize_file, "rb") as f: | ||
return pkl.load(f) | ||
|
||
def to_disk(self, path): | ||
""" Save component into disk """ | ||
serialize_file = self._get_serialize_file(path) | ||
with open(serialize_file, "wb") as f: | ||
return pkl.dump(self, f) | ||
|
||
def __hash__(self): | ||
""" Get hash representation of the component """ | ||
self_repr = f"{self.__class__.__name__}.{self.version()}.{str(self.__dict__)}" | ||
return zlib.crc32(self_repr.encode()) |
Oops, something went wrong.