Store activations in Docs when save_activations is enabled #11002

Merged: 42 commits, merged on Sep 13, 2022. The diff below shows the changes from the first 15 commits, so it still uses the original `store_activations` name (renamed to `save_activations` later in the PR).

Commits
b71c604  Store activations in Doc when `store_activations` is enabled (danieldk, Jun 22, 2022)
c3da32b  Change type of `store_activations` to `Union[bool, List[str]]` (danieldk, Jun 22, 2022)
789a447  Formatting fixes in Tagger (danieldk, Jun 22, 2022)
acf47e8  Support store_activations in spancat and morphologizer (danieldk, Jun 22, 2022)
42526a6  Make Doc.activations type visible to MyPy (danieldk, Jun 22, 2022)
8772b9c  textcat/textcat_multilabel: add store_activations option (danieldk, Jun 22, 2022)
1c9be0d  trainable_lemmatizer/entity_linker: add store_activations option (danieldk, Jun 23, 2022)
009a960  parser/ner: do not currently support returning activations (danieldk, Jun 23, 2022)
c8a12c5  Extend tagger and senter tests (danieldk, Jun 24, 2022)
508b96f  Merge remote-tracking branch 'upstream/master' into store-activations (danieldk, Jun 27, 2022)
3b13f17  Document `Doc.activations` and `store_activations` in the relevant pipes (danieldk, Jun 27, 2022)
df513be  Merge remote-tracking branch 'upstream/v4' into store-activations (danieldk, Jun 27, 2022)
5eeb2e8  Start errors/warnings at higher numbers to avoid merge conflicts (danieldk, Jul 6, 2022)
403b1f1  Add `store_activations` to docstrings (danieldk, Jul 6, 2022)
288d27e  Merge remote-tracking branch 'upstream/v4' into store-activations (danieldk, Aug 1, 2022)
51f72e4  Replace store_activations setter by set_store_activations method (danieldk, Aug 4, 2022)
6e7b958  Use dict comprehension suggested by @svlandeg (danieldk, Aug 4, 2022)
230264d  Revert "Use dict comprehension suggested by @svlandeg" (danieldk, Aug 5, 2022)
57caae8  EntityLinker: add type annotations to _add_activations (danieldk, Aug 5, 2022)
ce36f34  _store_activations: make kwarg-only, remove doc_scores_lens arg (danieldk, Aug 5, 2022)
75d76cb  set_annotations: add type annotations (danieldk, Aug 5, 2022)
c792019  Apply suggestions from code review (danieldk, Aug 5, 2022)
1cfbb93  TextCat.predict: return dict (danieldk, Aug 5, 2022)
aea5337  Make the `TrainablePipe.store_activations` property a bool (danieldk, Aug 29, 2022)
8c2652d  Remove `TrainablePipe.activations` (danieldk, Aug 29, 2022)
8f84e6e  Add type annotations for activations in predict/set_annotations (danieldk, Aug 30, 2022)
3937abd  Merge remote-tracking branch 'upstream/v4' into store-activations (danieldk, Aug 30, 2022)
2290a04  Rename `TrainablePipe.store_activations` to `save_activations` (danieldk, Aug 30, 2022)
cdbda0b  Error E1400 is not used anymore (danieldk, Aug 30, 2022)
51c87c5  Change wording in API docs after store -> save change (danieldk, Aug 30, 2022)
d245da0  docs: tag (save_)activations as new in spaCy 4.0 (danieldk, Aug 30, 2022)
699a187  Fix copied line in morphologizer activations test (danieldk, Aug 30, 2022)
6f80e80  Don't train in any test_save_activations test (danieldk, Aug 30, 2022)
cd6e4fa  Rename activations (danieldk, Aug 31, 2022)
2593dad  Remove unused W400 warning (danieldk, Aug 31, 2022)
ac5b1fd  Formatting fixes (danieldk, Sep 7, 2022)
b828954  Replace "kb_ids" by a constant (danieldk, Sep 8, 2022)
0bd5730  spancat: replace a cast by an assertion (danieldk, Sep 8, 2022)
8aba93e  Fix EOF spacing (danieldk, Sep 8, 2022)
a32dd70  Fix comments in test_save_activations tests (danieldk, Sep 8, 2022)
3a16b7b  Do not set RNG seed in activation saving tests (danieldk, Sep 8, 2022)
fc78ed4  Revert "spancat: replace a cast by an assertion" (danieldk, Sep 8, 2022)
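Taken together, these commits let a trainable pipe copy selected model activations into `Doc.activations`, keyed by pipe name. Below is a minimal usage sketch, not a verified API: the pipeline path is a placeholder, and the option appears as `store_activations` in the diff further down but was renamed to `save_activations` in commit 2290a04 before the merge.

```python
import spacy

# Placeholder path: assumes a pipeline containing a trained
# trainable_lemmatizer, built against the spaCy v4 branch this PR targets.
nlp = spacy.load("path/to/pipeline")

# Ask the pipe to save its activations on each Doc it annotates
# ("store_activations" in this diff, "save_activations" after 2290a04).
nlp.get_pipe("trainable_lemmatizer").save_activations = True

doc = nlp("She was reading the newspaper")

# Activations are stored per pipe name, then per activation name.
acts = doc.activations["trainable_lemmatizer"]
print(acts["probs"].shape)    # per-token edit-tree probabilities
print(acts["guesses"].shape)  # per-token chosen edit-tree ids
```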
2 changes: 2 additions & 0 deletions spacy/errors.py
@@ -212,6 +212,7 @@ class Warnings(metaclass=ErrorsWithCodes):
W121 = ("Attempting to trace non-existent method '{method}' in pipe '{pipe}'")
W122 = ("Couldn't trace method '{method}' in pipe '{pipe}'. This can happen if the pipe class "
"is a Cython extension type.")
W400 = ("Activation '{activation}' is unknown for pipe '{pipe_name}'")


class Errors(metaclass=ErrorsWithCodes):
@@ -939,6 +940,7 @@ class Errors(metaclass=ErrorsWithCodes):
"`{arg2}`={arg2_values} but these arguments are conflicting.")
E1043 = ("Expected None or a value in range [{range_start}, {range_end}] for entity linker threshold, but got "
"{value}.")
E1400 = ("store_activations attribute must be set to List[str] or bool")


# Deprecated model shortcuts, only used in errors and warnings
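W400 and E1400 back the validation of the `store_activations` value in this snapshot of the branch; per commits cdbda0b and 2593dad they were removed again before the merge. A hedged sketch of how such codes are typically consumed follows; the call sites are illustrations, not part of this diff, and W400/E1400 do not exist in released spaCy versions.

```python
import warnings

from spacy.errors import Errors, Warnings

# Illustrative only: these helpers are not part of the PR. The
# ErrorsWithCodes metaclass prefixes each message with its code.
def warn_unknown_activation(pipe_name: str, activation: str, known: set) -> None:
    if activation not in known:
        warnings.warn(Warnings.W400.format(activation=activation, pipe_name=pipe_name))

def validate_store_activations(value) -> None:
    if not isinstance(value, (bool, list)):
        raise ValueError(Errors.E1400)
```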
33 changes: 27 additions & 6 deletions spacy/pipeline/edit_tree_lemmatizer.py
@@ -7,7 +7,7 @@

import srsly
from thinc.api import Config, Model, SequenceCategoricalCrossentropy
from thinc.types import Floats2d, Ints1d, Ints2d
from thinc.types import ArrayXd, Floats2d, Ints1d

from ._edit_tree_internals.edit_trees import EditTrees
from ._edit_tree_internals.schemas import validate_edit_tree
@@ -21,6 +21,9 @@
from .. import util


ActivationsT = Dict[str, Union[List[Floats2d], List[Ints1d]]]


default_model_config = """
[model]
@architectures = "spacy.Tagger.v2"
@@ -49,6 +52,7 @@
"overwrite": False,
"top_k": 1,
"scorer": {"@scorers": "spacy.lemmatizer_scorer.v1"},
"store_activations": False,
},
default_score_weights={"lemma_acc": 1.0},
)
@@ -61,6 +65,7 @@ def make_edit_tree_lemmatizer(
overwrite: bool,
top_k: int,
scorer: Optional[Callable],
store_activations: Union[bool, List[str]],
):
"""Construct an EditTreeLemmatizer component."""
return EditTreeLemmatizer(
@@ -72,6 +77,7 @@ def make_edit_tree_lemmatizer(
overwrite=overwrite,
top_k=top_k,
scorer=scorer,
store_activations=store_activations,
)


@@ -91,6 +97,7 @@ def __init__(
overwrite: bool = False,
top_k: int = 1,
scorer: Optional[Callable] = lemmatizer_score,
store_activations=False,
):
"""
Construct an edit tree lemmatizer.
@@ -102,6 +109,8 @@ def __init__(
frequency in the training data.
overwrite (bool): overwrite existing lemma annotations.
top_k (int): try to apply at most the k most probable edit trees.
store_activations (Union[bool, List[str]]): Model activations to store in
Doc when annotating. Supported activations are: "probs" and "guesses".
"""
self.vocab = vocab
self.model = model
Expand All @@ -116,6 +125,7 @@ def __init__(

self.cfg: Dict[str, Any] = {"labels": []}
self.scorer = scorer
self.store_activations = store_activations # type: ignore

def get_loss(
self, examples: Iterable[Example], scores: List[Floats2d]
@@ -144,21 +154,24 @@

return float(loss), d_scores

def predict(self, docs: Iterable[Doc]) -> List[Ints2d]:
def predict(self, docs: Iterable[Doc]) -> ActivationsT:
n_docs = len(list(docs))
if not any(len(doc) for doc in docs):
# Handle cases where there are no tokens in any docs.
n_labels = len(self.cfg["labels"])
guesses: List[Ints2d] = [
guesses: List[Ints1d] = [
self.model.ops.alloc((0,), dtype="i") for doc in docs
]
scores: List[Floats2d] = [
self.model.ops.alloc((0, n_labels), dtype="i") for doc in docs
]
assert len(guesses) == n_docs
return guesses
return {"probs": scores, "guesses": guesses}
scores = self.model.predict(docs)
assert len(scores) == n_docs
guesses = self._scores2guesses(docs, scores)
assert len(guesses) == n_docs
return guesses
return {"probs": scores, "guesses": guesses}

def _scores2guesses(self, docs, scores):
guesses = []
@@ -186,8 +199,12 @@ def _scores2guesses(self, docs, scores):

return guesses

def set_annotations(self, docs: Iterable[Doc], batch_tree_ids):
def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT):
batch_tree_ids = activations["guesses"]
for i, doc in enumerate(docs):
doc.activations[self.name] = {}
for activation in self.store_activations:
doc.activations[self.name][activation] = activations[activation][i]
doc_tree_ids = batch_tree_ids[i]
if hasattr(doc_tree_ids, "get"):
doc_tree_ids = doc_tree_ids.get()
@@ -377,3 +394,7 @@ def _pair2label(self, form, lemma, add_label=False):
self.tree2label[tree_id] = len(self.cfg["labels"])
self.cfg["labels"].append(tree_id)
return self.tree2label[tree_id]

@property
def activations(self):
return ["probs", "guesses"]
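In this snapshot, `set_annotations` copies only the activations listed in `self.store_activations` onto the `Doc`, keyed by the pipe name (later commits turn the option into a plain bool). A standalone illustration of that storage pattern, in plain Python rather than spaCy code:

```python
from typing import Dict, List

class DummyPipe:
    """Toy stand-in for a trainable pipe; not spaCy code."""

    name = "trainable_lemmatizer"

    def __init__(self, store_activations: List[str]):
        self.store_activations = store_activations

    def set_annotations(self, doc_acts: Dict[str, dict], batch: Dict[str, list], i: int) -> None:
        # Copy only the requested activations for doc i, keyed by pipe name.
        doc_acts[self.name] = {key: batch[key][i] for key in self.store_activations}

pipe = DummyPipe(store_activations=["probs"])
doc_acts: Dict[str, dict] = {}
batch = {"probs": [[0.9, 0.1]], "guesses": [[0]]}
pipe.set_annotations(doc_acts, batch, i=0)
print(doc_acts)  # {'trainable_lemmatizer': {'probs': [0.9, 0.1]}}
```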
92 changes: 83 additions & 9 deletions spacy/pipeline/entity_linker.py
@@ -1,5 +1,7 @@
from typing import Optional, Iterable, Callable, Dict, Union, List, Any
from thinc.types import Floats2d
from typing import cast
from numpy import dtype
from thinc.types import Floats2d, Ragged
from pathlib import Path
from itertools import islice
import srsly
@@ -21,6 +23,9 @@
from .. import util
from ..scorer import Scorer


ActivationsT = Dict[str, Union[List[Ragged], List[str]]]

# See #9050
BACKWARD_OVERWRITE = True

@@ -57,6 +62,7 @@
"scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
"use_gold_ents": True,
"threshold": None,
"store_activations": False,
},
default_score_weights={
"nel_micro_f": 1.0,
@@ -79,6 +85,7 @@ def make_entity_linker(
scorer: Optional[Callable],
use_gold_ents: bool,
threshold: Optional[float] = None,
store_activations: Union[bool, List[str]],
):
"""Construct an EntityLinker component.

@@ -97,6 +104,8 @@ def make_entity_linker(
component must provide entity annotations.
threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold,
prediction is discarded. If None, predictions are not filtered by any threshold.
store_activations (Union[bool, List[str]]): Model activations to store in
Doc when annotating. Supported activations are: "ents" and "scores".
"""

if not model.attrs.get("include_span_maker", False):
@@ -128,6 +137,7 @@ def make_entity_linker(
scorer=scorer,
use_gold_ents=use_gold_ents,
threshold=threshold,
store_activations=store_activations,
)


@@ -164,6 +174,7 @@ def __init__(
scorer: Optional[Callable] = entity_linker_score,
use_gold_ents: bool,
threshold: Optional[float] = None,
store_activations=False,
) -> None:
"""Initialize an entity linker.

@@ -212,6 +223,7 @@ def __init__(
self.scorer = scorer
self.use_gold_ents = use_gold_ents
self.threshold = threshold
self.store_activations = store_activations

def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
"""Define the KB of this pipe by providing a function that will
@@ -397,7 +409,7 @@ def get_loss(self, examples: Iterable[Example], sentence_encodings: Floats2d):
loss = loss / len(entity_encodings)
return float(loss), out

def predict(self, docs: Iterable[Doc]) -> List[str]:
def predict(self, docs: Iterable[Doc]) -> ActivationsT:
"""Apply the pipeline's model to a batch of docs, without modifying them.
Returns the KB IDs for each entity in each doc, including NIL if there is
no prediction.
@@ -410,13 +422,21 @@ def predict(self, docs: Iterable[Doc]) -> List[str]:
self.validate_kb()
entity_count = 0
final_kb_ids: List[str] = []
xp = self.model.ops.xp
ops = self.model.ops
xp = ops.xp
docs_ents: List[Ragged] = []
docs_scores: List[Ragged] = []
if not docs:
return final_kb_ids
return {"kb_ids": final_kb_ids, "ents": docs_ents, "scores": docs_scores}
if isinstance(docs, Doc):
docs = [docs]
for i, doc in enumerate(docs):
for doc in docs:
doc_ents = []
doc_scores = []
doc_scores_lens: List[int] = []
if len(doc) == 0:
doc_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0)))
doc_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0)))
continue
sentences = [s for s in doc.sents]
# Looping through each entity (TODO: rewrite)
@@ -439,14 +459,23 @@
if ent.label_ in self.labels_discard:
# ignoring this entity - setting to NIL
final_kb_ids.append(self.NIL)
self._add_activations(
doc_scores, doc_scores_lens, doc_ents, [0.0], [0]
)
else:
candidates = list(self.get_candidates(self.kb, ent))
if not candidates:
# no prediction possible for this entity - setting to NIL
final_kb_ids.append(self.NIL)
self._add_activations(
doc_scores, doc_scores_lens, doc_ents, [0.0], [0]
)
elif len(candidates) == 1 and self.threshold is None:
# shortcut for efficiency reasons: take the 1 candidate
final_kb_ids.append(candidates[0].entity_)
self._add_activations(
doc_scores, doc_scores_lens, doc_ents, [1.0], [candidates[0].entity_]
)
else:
random.shuffle(candidates)
# set all prior probabilities to 0 if incl_prior=False
@@ -479,27 +508,45 @@
if self.threshold is None or scores.max() >= self.threshold
else EntityLinker.NIL
)
self._add_activations(
doc_scores,
doc_scores_lens,
doc_ents,
scores,
[c.entity for c in candidates],
)
self._add_doc_activations(
docs_scores, docs_ents, doc_scores, doc_scores_lens, doc_ents
)
if not (len(final_kb_ids) == entity_count):
err = Errors.E147.format(
method="predict", msg="result variables not of equal length"
)
raise RuntimeError(err)
return final_kb_ids
return {"kb_ids": final_kb_ids, "ents": docs_ents, "scores": docs_scores}

def set_annotations(self, docs: Iterable[Doc], kb_ids: List[str]) -> None:
def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT) -> None:
"""Modify a batch of documents, using pre-computed scores.

docs (Iterable[Doc]): The documents to modify.
kb_ids (List[str]): The IDs to set, produced by EntityLinker.predict.
activations (List[str]): The activations used for setting annotations, produced
by EntityLinker.predict.

DOCS: https://spacy.io/api/entitylinker#set_annotations
"""
kb_ids = cast(List[str], activations["kb_ids"])
count_ents = len([ent for doc in docs for ent in doc.ents])
if count_ents != len(kb_ids):
raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids)))
i = 0
overwrite = self.cfg["overwrite"]
for doc in docs:
for j, doc in enumerate(docs):
doc.activations[self.name] = {}
for activation in self.store_activations:
# We only copy activations that are Ragged.
doc.activations[self.name][activation] = cast(
Ragged, activations[activation][j]
)
for ent in doc.ents:
kb_id = kb_ids[i]
i += 1
@@ -598,3 +645,30 @@ def rehearse(self, examples, *, sgd=None, losses=None, **config):

def add_label(self, label):
raise NotImplementedError

@property
def activations(self):
return ["ents", "scores"]

def _add_doc_activations(
self, docs_scores, docs_ents, doc_scores, doc_scores_lens, doc_ents
):
if len(self.store_activations) == 0:
return
ops = self.model.ops
docs_scores.append(
Ragged(ops.flatten(doc_scores), ops.asarray1i(doc_scores_lens))
)
docs_ents.append(
Ragged(
ops.flatten(doc_ents, dtype="uint64"), ops.asarray1i(doc_scores_lens)
)
)

def _add_activations(self, doc_scores, doc_scores_lens, doc_ents, scores, ents):
if len(self.store_activations) == 0:
return
ops = self.model.ops
doc_scores.append(ops.asarray1f(scores))
doc_scores_lens.append(doc_scores[-1].shape[0])
doc_ents.append(ops.xp.array(ents, dtype="uint64"))
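`_add_activations` and `_add_doc_activations` flatten the per-entity candidate scores and KB ids into one `Ragged` per doc, so the candidates of entity i are the i-th slice of the flattened data. A standalone sketch of reading that layout back, assuming thinc's `Ragged(data, lengths)` constructor and made-up numbers:

```python
import numpy
from thinc.types import Ragged

# Two entities: the first had three candidates, the second had one.
scores = Ragged(
    numpy.asarray([0.7, 0.2, 0.1, 1.0], dtype="float32"),
    numpy.asarray([3, 1], dtype="int32"),
)

start = 0
for i, n_candidates in enumerate(scores.lengths):
    end = start + int(n_candidates)
    print(f"entity {i}: candidate scores {scores.data[start:end]}")
    start = end
```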