GH-1362: Add option for hidden state position in FlairEmbeddings #1571

Merged
merged 6 commits into from
May 3, 2020
Merged
Changes from all commits
flair/embeddings.py: 56 changes (41 additions, 15 deletions)
@@ -1836,16 +1836,27 @@ def __str__(self):
 class FlairEmbeddings(TokenEmbeddings):
     """Contextual string embeddings of words, as proposed in Akbik et al., 2018."""

-    def __init__(self, model, fine_tune: bool = False, chars_per_chunk: int = 512):
+    def __init__(self,
+                 model,
+                 fine_tune: bool = False,
+                 chars_per_chunk: int = 512,
+                 with_whitespace: bool = True,
+                 tokenized_lm: bool = True,
+                 ):
         """
         initializes contextual string embeddings using a character-level language model.
         :param model: model string, one of 'news-forward', 'news-backward', 'news-forward-fast', 'news-backward-fast',
-               'mix-forward', 'mix-backward', 'german-forward', 'german-backward', 'polish-backward', 'polish-forward'
+               'mix-forward', 'mix-backward', 'german-forward', 'german-backward', 'polish-backward', 'polish-forward',
+               etc (see https://github.com/flairNLP/flair/blob/master/resources/docs/embeddings/FLAIR_EMBEDDINGS.md)
                depending on which character language model is desired.
-        :param fine_tune: if set to True, the gradient will propagate into the language model. This dramatically slows down
-               training and often leads to overfitting, so use with caution.
-        :param chars_per_chunk: max number of chars per rnn pass to control speed/memory tradeoff. Higher means faster but requires
-               more memory. Lower means slower but less memory.
+        :param fine_tune: if set to True, the gradient will propagate into the language model. This dramatically slows
+               down training and often leads to overfitting, so use with caution.
+        :param chars_per_chunk: max number of chars per rnn pass to control speed/memory tradeoff. Higher means faster
+               but requires more memory. Lower means slower but less memory.
+        :param with_whitespace: If True, use hidden state after whitespace after word. If False, use hidden
+               state at last character of word.
+        :param tokenized_lm: Whether this lm is tokenized. Default is True, but for LMs trained over unprocessed text
+               False might be better.
         """
         super().__init__()

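A minimal usage sketch of the two new options, assuming a standard Flair setup and the 'news-forward' model named in the docstring (the sentence is an arbitrary example):

```python
from flair.data import Sentence
from flair.embeddings import FlairEmbeddings

# take each word's embedding at its last character instead of at the following
# whitespace, and feed the character LM plain (untokenized) text
embedding = FlairEmbeddings(
    "news-forward",
    with_whitespace=False,
    tokenized_lm=False,
)

sentence = Sentence("I love Berlin.")
embedding.embed(sentence)
print(sentence.tokens[0].embedding.shape)  # one vector per token
```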
@@ -2000,6 +2011,8 @@ def __init__(self, model, fine_tune: bool = False, chars_per_chunk: int = 512):
         self.static_embeddings = not fine_tune

         self.is_forward_lm: bool = self.lm.is_forward_lm
+        self.with_whitespace: bool = with_whitespace
+        self.tokenized_lm: bool = tokenized_lm
         self.chars_per_chunk: int = chars_per_chunk

         # embed a dummy sentence to determine embedding_length
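For context on the `static_embeddings = not fine_tune` line above: the embedding pass below chooses its gradient context from the same flag, so gradients only reach the language model when fine-tuning. A standalone sketch of that pattern in plain PyTorch, not Flair's actual code path:

```python
import torch

fine_tune = False                   # mirrors the constructor argument
static_embeddings = not fine_tune   # frozen embeddings need no gradient bookkeeping

# same selection as in _add_embeddings_internal below
gradient_context = torch.enable_grad() if fine_tune else torch.no_grad()

lm_weights = torch.randn(8, 8, requires_grad=True)  # stand-in for LM parameters
with gradient_context:
    hidden = torch.randn(1, 8) @ lm_weights
print(hidden.requires_grad)  # False here; True only if fine_tune is True
```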
@@ -2032,13 +2045,20 @@ def embedding_length(self) -> int:

     def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:

+        # make compatible with serialized models (TODO: remove)
+        if "with_whitespace" not in self.__dict__:
+            self.with_whitespace = True
+        if "tokenized_lm" not in self.__dict__:
+            self.tokenized_lm = True
+
         # gradients are enabled if fine-tuning is enabled
         gradient_context = torch.enable_grad() if self.fine_tune else torch.no_grad()

         with gradient_context:

             # if this is not possible, use LM to generate embedding. First, get text sentences
-            text_sentences = [sentence.to_tokenized_string() for sentence in sentences]
+            text_sentences = [sentence.to_tokenized_string() for sentence in sentences] if self.tokenized_lm \
+                else [sentence.to_plain_string() for sentence in sentences]

             start_marker = "\n"
             end_marker = " "
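The `__dict__` guard above is needed because unpickling does not re-run `__init__`, so FlairEmbeddings objects saved before this change come back without the new attributes. A minimal standalone sketch of the problem and the fix (the `Embedder` class is a hypothetical stand-in, not Flair code):

```python
import pickle

class Embedder:
    def __init__(self):
        self.chars_per_chunk = 512           # attribute that existed in older releases

old_blob = pickle.dumps(Embedder())          # simulate a model saved with the old class

class Embedder:                              # newer class definition adds attributes
    def __init__(self):
        self.chars_per_chunk = 512
        self.with_whitespace = True
        self.tokenized_lm = True

restored = pickle.loads(old_blob)            # __init__ is not called on unpickling
print("with_whitespace" in restored.__dict__)  # False: the attribute is missing

# so defaults are patched in lazily, exactly like the guard above
if "with_whitespace" not in restored.__dict__:
    restored.with_whitespace = True
if "tokenized_lm" not in restored.__dict__:
    restored.tokenized_lm = True
```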
@@ -2053,25 +2073,31 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:

             # take first or last hidden states from language model as word representation
             for i, sentence in enumerate(sentences):
-                sentence_text = sentence.to_tokenized_string()
+                sentence_text = sentence.to_tokenized_string() if self.tokenized_lm else sentence.to_plain_string()

                 offset_forward: int = len(start_marker)
                 offset_backward: int = len(sentence_text) + len(start_marker)

                 for token in sentence.tokens:

                     offset_forward += len(token.text)

                     if self.is_forward_lm:
-                        offset = offset_forward
+                        offset_with_whitespace = offset_forward
+                        offset_without_whitespace = offset_forward - 1
                     else:
-                        offset = offset_backward
+                        offset_with_whitespace = offset_backward
+                        offset_without_whitespace = offset_backward - 1

-                    embedding = all_hidden_states_in_lm[offset, i, :]
+                    # offset mode that extracts at whitespace after last character
+                    if self.with_whitespace:
+                        embedding = all_hidden_states_in_lm[offset_with_whitespace, i, :]
+                    # offset mode that extracts at last character
+                    else:
+                        embedding = all_hidden_states_in_lm[offset_without_whitespace, i, :]

-                    # if self.tokenized_lm or token.whitespace_after:
-                    offset_forward += 1
-                    offset_backward -= 1
+                    if self.tokenized_lm or token.whitespace_after:
+                        offset_forward += 1
+                        offset_backward -= 1

                     offset_backward -= len(token.text)
Expand Down