Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FlairNLP Sequence Tagging #55

Merged
merged 6 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ following table provides an overview about them:
| SpacyPosClassifier | Part-of-speech prediction with [spaCy](https://spacy.io/) | no |
| AdapterSequenceTagger | Sequence tagger using [Adapters](https://adapterhub.ml/) | no |
| AdapterSentenceClassifier | Sentence classifier using [Adapters](https://adapterhub.ml/) | no |
| FlairNERClassifier | Sequence tagger using [Flair](https://flairnlp.github.io/) | no |

For using trainable recommenders it is important to check the checkbox *Trainable* when adding
the external recommender to your project. To be able to get predictions of a added trainable
Expand Down
75 changes: 75 additions & 0 deletions ariadne/contrib/flair.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Licensed to the Technische Universität Darmstadt under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The Technische Universität Darmstadt
# licenses this file to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path

Check warning on line 16 in ariadne/contrib/flair.py

View check run for this annotation

Codecov / codecov/patch

ariadne/contrib/flair.py#L16

Added line #L16 was not covered by tests

from cassis import Cas

Check warning on line 18 in ariadne/contrib/flair.py

View check run for this annotation

Codecov / codecov/patch

ariadne/contrib/flair.py#L18

Added line #L18 was not covered by tests

from flair.nn import Classifier as Tagger
from flair.data import Sentence

Check warning on line 21 in ariadne/contrib/flair.py

View check run for this annotation

Codecov / codecov/patch

ariadne/contrib/flair.py#L20-L21

Added lines #L20 - L21 were not covered by tests

from ariadne.classifier import Classifier
from ariadne.contrib.inception_util import create_prediction, SENTENCE_TYPE, TOKEN_TYPE

Check warning on line 24 in ariadne/contrib/flair.py

View check run for this annotation

Codecov / codecov/patch

ariadne/contrib/flair.py#L23-L24

Added lines #L23 - L24 were not covered by tests


class FlairNERClassifier(Classifier):
def __init__(self, model_name: str, model_directory: Path = None, split_sentences: bool = True):
super().__init__(model_directory=model_directory)
self._model = Tagger.load(model_name)
self._split_sentences = split_sentences

Check warning on line 31 in ariadne/contrib/flair.py

View check run for this annotation

Codecov / codecov/patch

ariadne/contrib/flair.py#L27-L31

Added lines #L27 - L31 were not covered by tests

def predict(self, cas: Cas, layer: str, feature: str, project_id: str, document_id: str, user_id: str):

Check warning on line 33 in ariadne/contrib/flair.py

View check run for this annotation

Codecov / codecov/patch

ariadne/contrib/flair.py#L33

Added line #L33 was not covered by tests
# Extract the sentences from the CAS
if self._split_sentences:
cas_sents = cas.select(SENTENCE_TYPE)
sents = [Sentence(sent.get_covered_text(), use_tokenizer=False) for sent in cas_sents]
offsets = [sent.begin for sent in cas_sents]

Check warning on line 38 in ariadne/contrib/flair.py

View check run for this annotation

Codecov / codecov/patch

ariadne/contrib/flair.py#L35-L38

Added lines #L35 - L38 were not covered by tests

# Find the named entities
self._model.predict(sents)

Check warning on line 41 in ariadne/contrib/flair.py

View check run for this annotation

Codecov / codecov/patch

ariadne/contrib/flair.py#L41

Added line #L41 was not covered by tests

for offset, sent in zip(offsets, sents):

Check warning on line 43 in ariadne/contrib/flair.py

View check run for this annotation

Codecov / codecov/patch

ariadne/contrib/flair.py#L43

Added line #L43 was not covered by tests
# For every entity returned by spacy, create an annotation in the CAS
for named_entity in sent.to_dict()["entities"]:
begin = named_entity["start_pos"] + offset
end = named_entity["end_pos"] + offset
label = named_entity["labels"][0]["value"]
prediction = create_prediction(cas, layer, feature, begin, end, label)
cas.add(prediction)

Check warning on line 50 in ariadne/contrib/flair.py

View check run for this annotation

Codecov / codecov/patch

ariadne/contrib/flair.py#L45-L50

Added lines #L45 - L50 were not covered by tests

else:
cas_tokens = cas.select(TOKEN_TYPE)

Check warning on line 53 in ariadne/contrib/flair.py

View check run for this annotation

Codecov / codecov/patch

ariadne/contrib/flair.py#L53

Added line #L53 was not covered by tests

# build sentence with correct whitespaces
# (when using sentences, this should not be a problem afaik)
text = ""
last_end = 0
for cas_token in cas_tokens:
if cas_token.begin == last_end:
text += cas_token.get_covered_text()

Check warning on line 61 in ariadne/contrib/flair.py

View check run for this annotation

Codecov / codecov/patch

ariadne/contrib/flair.py#L57-L61

Added lines #L57 - L61 were not covered by tests
else:
text += " " + cas_token.get_covered_text()
last_end = cas_token.end

Check warning on line 64 in ariadne/contrib/flair.py

View check run for this annotation

Codecov / codecov/patch

ariadne/contrib/flair.py#L63-L64

Added lines #L63 - L64 were not covered by tests

sent = Sentence(text, use_tokenizer=False)

Check warning on line 66 in ariadne/contrib/flair.py

View check run for this annotation

Codecov / codecov/patch

ariadne/contrib/flair.py#L66

Added line #L66 was not covered by tests

self._model.predict(sent)

Check warning on line 68 in ariadne/contrib/flair.py

View check run for this annotation

Codecov / codecov/patch

ariadne/contrib/flair.py#L68

Added line #L68 was not covered by tests

for named_entity in sent.to_dict()["entities"]:
begin = named_entity["start_pos"]
end = named_entity["end_pos"]
label = named_entity["labels"][0]["value"]
prediction = create_prediction(cas, layer, feature, begin, end, label)
cas.add(prediction)

Check warning on line 75 in ariadne/contrib/flair.py

View check run for this annotation

Codecov / codecov/patch

ariadne/contrib/flair.py#L70-L75

Added lines #L70 - L75 were not covered by tests
7 changes: 4 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@
HOMEPAGE = "https://inception-project.github.io/"
EMAIL = "inception-users@googlegroups.com"
AUTHOR = "The INCEpTION team"
REQUIRES_PYTHON = ">=3.6.0"
REQUIRES_PYTHON = ">=3.8.0"

install_requires = [
"flask",
"filelock",
"dkpro-cassis>=0.7.6",
"dkpro-cassis>=0.9.1",
"joblib",
"gunicorn",
"deprecation",
Expand All @@ -50,7 +50,8 @@
"sentence-transformers~=2.2.2",
"lightgbm~=4.2.0",
"diskcache~=5.2.1",
"simalign~=0.4"
"simalign~=0.4",
"flair>=0.13.1"
]

test_dependencies = [
Expand Down
Loading