Skip to content

Commit

Permalink
Use activations stored by the senter pipe
Browse files Browse the repository at this point in the history
Before this change, we called the senter pipe's model directly. However, this did
not work with the transformer model without modifications (because the transformer
clears its tensors after backprop). By using the functionality proposed in

explosion/spaCy#11002

we can use the activations that are stored by the senter pipe in `Doc`.
  • Loading branch information
danieldk committed Jun 28, 2022
1 parent 97f6186 commit 04d2931
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 9 deletions.
2 changes: 2 additions & 0 deletions projects/biaffine_parser/configs/base-config.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ pooling = {"@layers":"reduce_mean.v1"}

[components.arc_predicter]
factory = "experimental_arc_predicter"
senter = "senter"

[components.arc_predicter.model]
@architectures = "spacy-experimental.PairwiseBilinear.v1"
Expand Down Expand Up @@ -78,6 +79,7 @@ pooling = {"@layers":"reduce_mean.v1"}

[components.senter]
factory = "senter"
store_activations = true

[components.senter.model]
@architectures = "spacy.Tagger.v1"
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ requires = [
"setuptools",
"wheel",
"Cython<3.0",
"spacy>=3.3.0,<3.4.0",
"spacy>=3.4.0,<3.5.0",
"numpy",
]
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
spacy>=3.3.0,<3.4.0
spacy>=3.4.0,<3.5.0

# Development dependencies
cython>=0.25,<3.0
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ long_description_content_type = text/markdown
zip_safe = false
python_requires = >=3.6
install_requires =
spacy>=3.3.0,<3.4.0
spacy>=3.4.0,<3.5.0

[options.entry_points]
spacy_architectures =
Expand Down
14 changes: 8 additions & 6 deletions spacy_experimental/biaffine_parser/arc_predicter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ class ArcPredicter(TrainablePipe):
docs = list(docs)

if self.senter:
lengths = split_greedily(docs, ops=self.model.ops, max_length=self.max_length, senter=self.senter, is_train=False)
lengths = split_lazily(docs, ops=self.model.ops, max_length=self.max_length, senter=self.senter, is_train=False)
else:
lengths = sents2lens(docs, ops=self.model.ops)
scores = self.model.predict((docs, lengths))
Expand Down Expand Up @@ -224,7 +224,7 @@ class ArcPredicter(TrainablePipe):
docs = [eg.predicted for eg in examples]

if self.senter:
lens = split_greedily(docs, ops=self.model.ops, max_length=self.max_length, senter=self.senter, is_train=True)
lens = split_lazily(docs, ops=self.model.ops, max_length=self.max_length, senter=self.senter, is_train=True)
else:
lens = sents2lens(docs, ops=self.model.ops)
if lens.sum() == 0:
Expand Down Expand Up @@ -316,11 +316,13 @@ def sents2lens(docs: List[Doc], *, ops: Ops) -> Ints1d:

return ops.asarray1i(lens)

def split_greedily(docs: List[Doc], *, ops: Ops, max_length: int, senter: SentenceRecognizer, is_train: bool):
split_predictions, _ = senter.model(docs, is_train)

def split_lazily(docs: List[Doc], *, ops: Ops, max_length: int, senter: SentenceRecognizer, is_train: bool) -> Ints1d:
lens = []
for (doc, scores) in zip(docs, split_predictions):
for doc in docs:
activations = doc.activations.get(senter.name, None)
if activations is None:
raise ValueError("Greedy splitting requires senter with `store_activations` enabled.")
scores = activations['probs']
split_recursive(scores[:,1], ops, max_length, lens)

assert sum(lens) == sum([len(doc) for doc in docs])
Expand Down

0 comments on commit 04d2931

Please sign in to comment.