Make coref entry points work without PyTorch (#23)
* Make coref entry points work without torch

Before this PR, in environments without PyTorch, using spacy-experimental
could fail due to attempts to load entry points. This change stubs the types
required for class definitions (torch.nn.Module and torch.Tensor) to object
when torch is not available, as sketched below.
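
A minimal sketch of the stubbing pattern, using thinc's has_torch flag and
illustrative names (the repo's real module layout differs):

    from thinc.util import has_torch

    if has_torch:
        import torch

        TorchModule = torch.nn.Module
        TorchTensor = torch.Tensor
    else:
        # Stubs keep the class definitions importable when torch is missing.
        TorchModule = object
        TorchTensor = object

    class ExampleScorer(TorchModule):  # defined with or without torch
        ...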

* Add explanatory comment

* Use has_torch instead of looking for AttributeError

* Add clear errors when attempting to use coref without torch

Without this, it could be unclear why coref didn't work without torch.
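
A sketch of the resulting guard; has_torch comes from thinc.util, and the
error message mirrors the one added to coref_model.py below:

    from thinc.util import has_torch

    if not has_torch:
        raise ImportError("Coref requires PyTorch: pip install thinc[torch]")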

* Move PyTorch implementations to separate files

This follows the model of the biaffine parser.

* Fix model name

* Run tests with and without PyTorch

This applies the same changes as #24, since they worked.

* Remove unused imports in coref
polm authored Sep 28, 2022
1 parent 89e129e commit fb3f24d
Showing 7 changed files with 407 additions and 385 deletions.
6 changes: 5 additions & 1 deletion azure-pipelines.yml
@@ -53,6 +53,10 @@ jobs:
 
   - script: |
       pip install -r requirements.txt
+      python -m pytest --pyargs spacy_experimental
+    displayName: 'Run tests without PyTorch'
+  - script: |
+      pip install "torch==1.10.0+cpu" -f https://download.pytorch.org/whl/torch_stable.html
       python -m pytest --pyargs spacy_experimental
-    displayName: 'Run tests'
+    displayName: 'Run tests with PyTorch'

2 changes: 1 addition & 1 deletion spacy_experimental/coref/coref_component.py
@@ -18,7 +18,7 @@
 
 from .coref_util import create_gold_scores, MentionClusters, create_head_span_idxs
 from .coref_util import get_clusters_from_doc, get_predicted_clusters
-from .coref_util import DEFAULT_CLUSTER_PREFIX, matches_coref_prefix
+from .coref_util import DEFAULT_CLUSTER_PREFIX
 
 from .coref_scorer import score_coref_clusters
 
280 changes: 8 additions & 272 deletions spacy_experimental/coref/coref_model.py
@@ -3,12 +3,14 @@
 from thinc.api import Model, chain, get_width
 from thinc.api import PyTorchWrapper, ArgsKwargs
 from thinc.types import Floats2d, Ints2d
-from thinc.util import torch, xp2torch, torch2xp
+from thinc.util import xp2torch, torch2xp
 
 from spacy.tokens import Doc
 
 
-EPSILON = 1e-7
+try:
+    from .pytorch_coref_model import CorefClusterer
+except ImportError:
+    CorefClusterer = None
 
 
 def build_coref_model(
@@ -22,6 +24,9 @@ def build_coref_model(
     antecedent_batch_size: int,
 ) -> Model[List[Doc], Tuple[Floats2d, Ints2d]]:
 
+    if CorefClusterer is None:
+        raise ImportError("Coref requires PyTorch: pip install thinc[torch]")
+
     nI = None
 
     with Model.define_operators({">>": chain}):
@@ -111,272 +116,3 @@ def convert_for_torch_backward(dY: Floats2d) -> ArgsKwargs:
     scores_xp = cast(Floats2d, torch2xp(scores))
     indices_xp = cast(Ints2d, torch2xp(indices))
     return (scores_xp, indices_xp), convert_for_torch_backward
-
-
-class CorefClusterer(torch.nn.Module):
-    """
-    Combines all coref modules together to find coreferent token pairs.
-    Submodules (in the order of their usage in the pipeline):
-        - rough_scorer (RoughScorer) that prunes candidate pairs
-        - pairwise (DistancePairwiseEncoder) that computes pairwise features
-        - ana_scorer (AnaphoricityScorer) that produces the final scores
-    """
-
-    def __init__(
-        self,
-        dim: int,
-        dist_emb_size: int,
-        hidden_size: int,
-        n_layers: int,
-        dropout: float,
-        rough_k: int,
-        batch_size: int,
-    ):
-        super().__init__()
-        """
-        dim: Size of the input features.
-        dist_emb_size: Size of the distance embeddings.
-        hidden_size: Size of the coreference candidate embeddings.
-        n_layers: Number of layers in the AnaphoricityScorer.
-        dropout: Dropout probability to apply across all modules.
-        rough_k: Number of candidates the RoughScorer returns.
-        batch_size: Internal batch size for the more expensive scorer.
-        """
-        self.dropout = torch.nn.Dropout(dropout)
-        self.batch_size = batch_size
-        self.pairwise = DistancePairwiseEncoder(dist_emb_size, dropout)
-
-        pair_emb = dim * 3 + self.pairwise.shape
-        self.ana_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)
-        self.lstm = torch.nn.LSTM(
-            input_size=dim,
-            hidden_size=dim,
-            batch_first=True,
-        )
-
-        self.rough_scorer = RoughScorer(dim, dropout, rough_k)
-
-    def forward(self, word_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        1. LSTM encodes the incoming word_features.
-        2. The RoughScorer scores and prunes the candidates.
-        3. The DistancePairwiseEncoder embeds the distances between pairs.
-        4. The AnaphoricityScorer scores all pairs in mini-batches.
-        word_features: torch.Tensor containing word encodings
-        returns:
-            coref_scores: n_words x rough_k floats.
-            top_indices: n_words x rough_k integers.
-        """
-        self.lstm.flatten_parameters()  # XXX without this there's a warning
-        word_features = torch.unsqueeze(word_features, dim=0)
-        words, _ = self.lstm(word_features)
-        words = words.squeeze()
-        # words: n_words x dim
-        words = self.dropout(words)
-        # Obtain bilinear scores and leave only top-k antecedents for each word
-        # top_rough_scores: (n_words x rough_k)
-        # top_indices: (n_words x rough_k)
-        top_rough_scores, top_indices = self.rough_scorer(words)
-        # Get pairwise features
-        # (n_words x rough_k x n_pairwise_features)
-        pairwise = self.pairwise(top_indices)
-        batch_size = self.batch_size
-        a_scores_lst: List[torch.Tensor] = []
-
-        for i in range(0, len(words), batch_size):
-            pairwise_batch = pairwise[i : i + batch_size]
-            words_batch = words[i : i + batch_size]
-            top_indices_batch = top_indices[i : i + batch_size]
-            top_rough_scores_batch = top_rough_scores[i : i + batch_size]
-
-            # a_scores_batch: [batch_size, n_ants]
-            a_scores_batch = self.ana_scorer(
-                all_mentions=words,
-                mentions_batch=words_batch,
-                pairwise_batch=pairwise_batch,
-                top_indices_batch=top_indices_batch,
-                top_rough_scores_batch=top_rough_scores_batch,
-            )
-            a_scores_lst.append(a_scores_batch)
-
-        coref_scores = torch.cat(a_scores_lst, dim=0)
-        return coref_scores, top_indices
-
-
-# Note: this function is kept here to keep a torch dep out of coref_util.
-def add_dummy(tensor: torch.Tensor, eps: bool = False):
-    """Prepends zeros (or a very small value if eps is True)
-    to the first (not zeroth) dimension of tensor.
-    """
-    kwargs = dict(device=tensor.device, dtype=tensor.dtype)
-    shape: List[int] = list(tensor.shape)
-    shape[1] = 1
-    if not eps:
-        dummy = torch.zeros(shape, **kwargs)  # type: ignore
-    else:
-        dummy = torch.full(shape, EPSILON, **kwargs)  # type: ignore
-    output = torch.cat((dummy, tensor), dim=1)
-    return output
-
-
-class AnaphoricityScorer(torch.nn.Module):
-    """Calculates anaphoricity scores by passing the inputs into a FFNN."""
-
-    def __init__(self, in_features: int, hidden_size, depth, dropout):
-        super().__init__()
-        if not depth:
-            hidden_size = in_features
-        layers = []
-        for i in range(depth):
-            layers.extend(
-                [
-                    torch.nn.Linear(hidden_size if i else in_features, hidden_size),
-                    torch.nn.LeakyReLU(),
-                    torch.nn.Dropout(dropout),
-                ]
-            )
-        self.hidden = torch.nn.Sequential(*layers)
-        self.out = torch.nn.Linear(hidden_size, out_features=1)
-
-    def forward(
-        self,
-        *,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
-        all_mentions: torch.Tensor,
-        mentions_batch: torch.Tensor,
-        pairwise_batch: torch.Tensor,
-        top_indices_batch: torch.Tensor,
-        top_rough_scores_batch: torch.Tensor,
-    ) -> torch.Tensor:
-        """Builds a pairwise matrix, scores the pairs and returns the scores.
-        Args:
-            all_mentions (torch.Tensor): [n_mentions, mention_emb]
-            mentions_batch (torch.Tensor): [batch_size, mention_emb]
-            pairwise_batch (torch.Tensor): [batch_size, n_ants, pairwise_emb]
-            top_indices_batch (torch.Tensor): [batch_size, n_ants]
-            top_rough_scores_batch (torch.Tensor): [batch_size, n_ants]
-        Returns:
-            torch.Tensor [batch_size, n_ants + 1]:
-                anaphoricity scores for the pairs + a dummy column
-        """
-        # [batch_size, n_ants, pair_emb]
-        pair_matrix = self._get_pair_matrix(
-            all_mentions, mentions_batch, pairwise_batch, top_indices_batch
-        )
-
-        # [batch_size, n_ants]
-        scores = top_rough_scores_batch + self._ffnn(pair_matrix)
-        scores = add_dummy(scores, eps=True)
-
-        return scores
-
-    def _ffnn(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        x: tensor of shape (batch_size x rough_k x n_features)
-        returns: tensor of shape (batch_size x antecedent_limit)
-        """
-        x = self.out(self.hidden(x))
-        return x.squeeze(2)
-
-    @staticmethod
-    def _get_pair_matrix(
-        all_mentions: torch.Tensor,
-        mentions_batch: torch.Tensor,
-        pairwise_batch: torch.Tensor,
-        top_indices_batch: torch.Tensor,
-    ) -> torch.Tensor:
-        """
-        Builds the matrix used as input for AnaphoricityScorer.
-        all_mentions: (n_mentions x mention_emb),
-            all the valid mentions of the document,
-            can be on a different device
-        mentions_batch: (batch_size x mention_emb),
-            the mentions of the current batch.
-        pairwise_batch: (batch_size x rough_k x pairwise_emb),
-            pairwise distance features of the current batch.
-        top_indices_batch: (batch_size x n_ants),
-            indices of antecedents of each mention
-        Returns:
-            out: pairwise features (batch_size x n_ants x pair_emb)
-        """
-        emb_size = mentions_batch.shape[1]
-        n_ants = pairwise_batch.shape[1]
-
-        a_mentions = mentions_batch.unsqueeze(1).expand(-1, n_ants, emb_size)
-        b_mentions = all_mentions[top_indices_batch]
-        similarity = a_mentions * b_mentions
-
-        out = torch.cat((a_mentions, b_mentions, similarity, pairwise_batch), dim=2)
-        return out
-
-
-class RoughScorer(torch.nn.Module):
-    """
-    Cheaper module that gives a rough estimate of the anaphoricity of two
-    candidates; only the top scoring candidates are considered in later
-    steps to reduce computational cost.
-    """
-
-    def __init__(self, features: int, dropout: float, antecedent_limit: int):
-        super().__init__()
-        self.dropout = torch.nn.Dropout(dropout)
-        self.bilinear = torch.nn.Linear(features, features)
-        self.k = antecedent_limit
-
-    def forward(
-        self,  # type: ignore
-        mentions: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Returns rough anaphoricity scores for candidates, which consist of
-        the bilinear output of the current model summed with mention scores.
-        """
-        # pair_mask: [n_mentions, n_mentions]. The log of the boolean mask is
-        # 0.0 where the candidate precedes the word and -inf elsewhere, so
-        # adding it excludes invalid pairs from the top-k selection.
-        pair_mask = torch.arange(mentions.shape[0])
-        pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
-        pair_mask = torch.log((pair_mask > 0).to(torch.float))
-        pair_mask = pair_mask.to(mentions.device)
-        bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)
-        rough_scores = pair_mask + bilinear_scores
-        top_scores, indices = torch.topk(
-            rough_scores, k=min(self.k, len(rough_scores)), dim=1, sorted=False
-        )
-
-        return top_scores, indices
-
-
-class DistancePairwiseEncoder(torch.nn.Module):
-    def __init__(self, distance_embedding_size, dropout):
-        """
-        Takes top_indices, a ranked list of the most likely anaphora
-        candidates for each word. For each of these pairs it looks up a
-        distance embedding from a table, where the distance corresponds
-        to the log-distance.
-        distance_embedding_size: int,
-            Dimensionality of the distance-embeddings table.
-        dropout: float,
-            Dropout probability.
-        """
-        super().__init__()
-        emb_size = distance_embedding_size
-        self.distance_emb = torch.nn.Embedding(9, emb_size)
-        self.dropout = torch.nn.Dropout(dropout)
-        self.shape = emb_size
-
-    def forward(self, top_indices: torch.Tensor) -> torch.Tensor:
-        word_ids = torch.arange(0, top_indices.size(0))
-        # Distances 1-4 get their own buckets; longer distances fall into
-        # log2-scaled buckets, capped so there are 9 buckets in total.
-        distance = (word_ids.unsqueeze(1) - word_ids[top_indices]).clamp_min_(min=1)
-        log_distance = distance.to(torch.float).log2().floor_()
-        log_distance = log_distance.clamp_max_(max=6).to(torch.long)
-        distance = torch.where(distance < 5, distance - 1, log_distance + 2)
-        distance = distance.to(top_indices.device)
-        distance = self.distance_emb(distance)
-        return self.dropout(distance)
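
To make the distance bucketing in DistancePairwiseEncoder above concrete, here
is a standalone sketch of the same arithmetic (pure Python, illustration only):
distances 1-4 get dedicated buckets 0-3, and longer distances share log2-scaled
buckets 4-8, which is why the embedding table has exactly 9 rows.

    import math

    def distance_bucket(distance: int) -> int:
        distance = max(distance, 1)
        if distance < 5:
            return distance - 1
        # floor(log2(d)) is capped at 6, giving buckets 4..8 for d >= 5
        return min(int(math.floor(math.log2(distance))), 6) + 2

    assert [distance_bucket(d) for d in (1, 4, 5, 8, 16, 64, 500)] == [0, 3, 4, 5, 6, 8, 8]
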
2 changes: 0 additions & 2 deletions spacy_experimental/coref/coref_util.py
@@ -1,10 +1,8 @@
 from typing import List, Tuple, Dict
 from thinc.types import Ints1d, Ints2d, Floats2d
 from thinc.api import NumpyOps
-import srsly
 from spacy.language import Language
 from spacy.tokens import Doc
-import spacy.util as util
 
 # type alias to make writing this less tedious
 MentionClusters = List[List[Tuple[int, int]]]