Merge pull request #11089 from polm/coref/dimension-inference
Dimension inference in Coref
polm committed Jul 12, 2022
2 parents 0f3c456 + 07e8556 commit 90973fa
Showing 7 changed files with 216 additions and 34 deletions.
spacy/ml/models/coref.py (44 additions, 6 deletions)

@@ -1,6 +1,6 @@
 from typing import List, Tuple, Callable, cast

-from thinc.api import Model, chain
+from thinc.api import Model, chain, get_width
 from thinc.api import PyTorchWrapper, ArgsKwargs
 from thinc.types import Floats2d, Ints2d
 from thinc.util import torch, xp2torch, torch2xp
@@ -22,13 +22,48 @@ def build_wl_coref_model(
     # pairs to keep per mention after rough scoring
     antecedent_limit: int = 50,
     antecedent_batch_size: int = 512,
-    tok2vec_size: int = 768,  # tok2vec size
+    nI=None,
 ) -> Model[List[Doc], Tuple[Floats2d, Ints2d]]:

     with Model.define_operators({">>": chain}):
-        coref_clusterer = PyTorchWrapper(
+        coref_clusterer: Model[List[Floats2d], Tuple[Floats2d, Ints2d]] = Model(
+            "coref_clusterer",
+            forward=coref_forward,
+            init=coref_init,
+            dims={"nI": nI},
+            attrs={
+                "distance_embedding_size": distance_embedding_size,
+                "hidden_size": hidden_size,
+                "depth": depth,
+                "dropout": dropout,
+                "antecedent_limit": antecedent_limit,
+                "antecedent_batch_size": antecedent_batch_size,
+            },
+        )
+
+        model = tok2vec >> coref_clusterer
+        model.set_ref("coref_clusterer", coref_clusterer)
+        return model
+
+
+def coref_init(model: Model, X=None, Y=None):
+    if model.layers:
+        return
+
+    if X is not None and model.has_dim("nI") is None:
+        model.set_dim("nI", get_width(X))
+
+    hidden_size = model.attrs["hidden_size"]
+    depth = model.attrs["depth"]
+    dropout = model.attrs["dropout"]
+    antecedent_limit = model.attrs["antecedent_limit"]
+    antecedent_batch_size = model.attrs["antecedent_batch_size"]
+    distance_embedding_size = model.attrs["distance_embedding_size"]
+
+    model._layers = [
+        PyTorchWrapper(
             CorefClusterer(
-                tok2vec_size,
+                model.get_dim("nI"),
                 distance_embedding_size,
                 hidden_size,
                 depth,
@@ -39,9 +74,12 @@ def build_wl_coref_model(
             convert_inputs=convert_coref_clusterer_inputs,
             convert_outputs=convert_coref_clusterer_outputs,
         )
-        coref_model = tok2vec >> coref_clusterer
-    return coref_model
+        # TODO maybe we need mixed precision and grad scaling?
+    ]


+def coref_forward(model: Model, X, is_train: bool):
+    return model.layers[0](X, is_train)
+
 def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
     # The input here is List[Floats2d], one for each doc
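
The pattern above, a bare thinc Model whose init callback infers nI from the sample data and only then constructs the wrapped PyTorch layer, is Thinc's usual recipe for dimension inference. A minimal runnable sketch of the same idea, with a plain Linear standing in for the PyTorch-wrapped CorefClusterer (all names here are illustrative, not spaCy code):

from typing import Optional

import numpy
from thinc.api import Linear, Model, get_width


def build_lazy_model(nI: Optional[int] = None) -> Model:
    # Same shape as build_wl_coref_model: the dim is declared up front,
    # the real layer is only built in init.
    return Model("lazy_model", forward=lazy_forward, init=lazy_init, dims={"nI": nI})


def lazy_init(model: Model, X=None, Y=None):
    # Guard against double construction, as coref_init does.
    if model.layers:
        return
    # Infer the input width from the sample batch if it wasn't given.
    if X is not None and model.has_dim("nI") is None:
        model.set_dim("nI", get_width(X))
    # Only now is the inner layer created, with a known input size.
    inner = Linear(nO=2, nI=model.get_dim("nI"))
    inner.initialize()
    model._layers = [inner]


def lazy_forward(model: Model, X, is_train: bool):
    # Delegate to the lazily built inner layer.
    return model.layers[0](X, is_train)


sample = numpy.zeros((4, 96), dtype="f")
model = build_lazy_model()   # no width given...
model.initialize(X=sample)   # ...so it is inferred from the sample here
assert model.get_dim("nI") == 96
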
spacy/ml/models/coref_util.py (3 additions, 1 deletion)

@@ -147,7 +147,9 @@ def get_clusters_from_doc(doc) -> List[List[Tuple[int, int]]]:
     ints are char spans, to be tokenization independent.
     """
     out = []
-    for key, val in doc.spans.items():
+    keys = sorted(list(doc.spans.keys()))
+    for key in keys:
+        val = doc.spans[key]
         cluster = []
         for span in val:
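
This coref_util change makes get_clusters_from_doc deterministic: dicts preserve insertion order, so iterating doc.spans.items() directly yields clusters in whatever order the span groups happened to be added. A small illustration with hypothetical group keys:

spans_a = {"coref_clusters_2": [], "coref_clusters_1": []}
spans_b = {"coref_clusters_1": [], "coref_clusters_2": []}

# Same groups, different insertion order: plain iteration disagrees...
assert list(spans_a) != list(spans_b)
# ...while sorted keys give both docs the same cluster order.
assert sorted(spans_a) == sorted(spans_b)
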
spacy/ml/models/span_predictor.py (47 additions, 8 deletions)

@@ -1,6 +1,6 @@
 from typing import List, Tuple, cast

-from thinc.api import Model, chain, tuplify
+from thinc.api import Model, chain, tuplify, get_width
 from thinc.api import PyTorchWrapper, ArgsKwargs
 from thinc.types import Floats2d, Ints1d
 from thinc.util import torch, xp2torch, torch2xp
@@ -13,7 +13,6 @@
 @registry.architectures("spacy.SpanPredictor.v1")
 def build_span_predictor(
     tok2vec: Model[List[Doc], List[Floats2d]],
-    tok2vec_size: int = 768,
     hidden_size: int = 1024,
     distance_embedding_size: int = 64,
     conv_channels: int = 4,
@@ -23,10 +22,46 @@ def build_span_predictor(
 ):
     # TODO add model return types

+    nI = None
+
     with Model.define_operators({">>": chain, "&": tuplify}):
-        span_predictor = PyTorchWrapper(
+        span_predictor: Model[List[Floats2d], List[Floats2d]] = Model(
+            "span_predictor",
+            forward=span_predictor_forward,
+            init=span_predictor_init,
+            dims={"nI": nI},
+            attrs={
+                "distance_embedding_size": distance_embedding_size,
+                "hidden_size": hidden_size,
+                "conv_channels": conv_channels,
+                "window_size": window_size,
+                "max_distance": max_distance,
+            },
+        )
+        head_info = build_get_head_metadata(prefix)
+        model = (tok2vec & head_info) >> span_predictor
+        model.set_ref("span_predictor", span_predictor)
+
+        return model
+
+
+def span_predictor_init(model: Model, X=None, Y=None):
+    if model.layers:
+        return
+
+    if X is not None and model.has_dim("nI") is None:
+        model.set_dim("nI", get_width(X))
+
+    hidden_size = model.attrs["hidden_size"]
+    distance_embedding_size = model.attrs["distance_embedding_size"]
+    conv_channels = model.attrs["conv_channels"]
+    window_size = model.attrs["window_size"]
+    max_distance = model.attrs["max_distance"]
+
+    model._layers = [
+        PyTorchWrapper(
             SpanPredictor(
-                tok2vec_size,
+                model.get_dim("nI"),
                 hidden_size,
                 distance_embedding_size,
                 conv_channels,
@@ -35,10 +70,12 @@ def build_span_predictor(
             ),
             convert_inputs=convert_span_predictor_inputs,
         )
-        head_info = build_get_head_metadata(prefix)
-        model = (tok2vec & head_info) >> span_predictor
+        # TODO maybe we need mixed precision and grad scaling?
+    ]

-    return model
+
+def span_predictor_forward(model: Model, X, is_train: bool):
+    return model.layers[0](X, is_train)


 def convert_span_predictor_inputs(
@@ -61,7 +98,9 @@ def backprop(args: ArgsKwargs) -> Tuple[List[Floats2d], None]:
     else:
         head_ids_tensor = xp2torch(head_ids[0], requires_grad=False)

-    argskwargs = ArgsKwargs(args=(sent_ids_tensor, word_features, head_ids_tensor), kwargs={})
+    argskwargs = ArgsKwargs(
+        args=(sent_ids_tensor, word_features, head_ids_tensor), kwargs={}
+    )
     return argskwargs, backprop
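
build_span_predictor composes its pieces with the operators it declares: "&" (tuplify) runs two models on the same input and pairs their outputs, while ">>" (chain) pipes one model into the next. A toy sketch of those combinators with throwaway layers (illustrative only, not the real tok2vec or span predictor):

import numpy
from thinc.api import Model, chain, tuplify

double = Model("double", lambda model, X, is_train: (X * 2, lambda dY: dY))
triple = Model("triple", lambda model, X, is_train: (X * 3, lambda dY: dY))
add = Model("add", lambda model, X, is_train: (X[0] + X[1], lambda dY: (dY, dY)))

with Model.define_operators({">>": chain, "&": tuplify}):
    # Same shape as (tok2vec & head_info) >> span_predictor: both branches
    # see the same input, and the pair they produce feeds the next layer.
    model = (double & triple) >> add

X = numpy.ones((2, 2), dtype="f")
Y, backprop = model(X, is_train=False)
assert (Y == 5.0).all()  # X*2 and X*3, summed
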
spacy/pipeline/coref.py (56 additions, 2 deletions)

@@ -6,14 +6,15 @@
 from thinc.api import set_dropout_rate, to_categorical
 from itertools import islice
 from statistics import mean
+import srsly

 from .trainable_pipe import TrainablePipe
 from ..language import Language
 from ..training import Example, validate_examples, validate_get_examples
 from ..errors import Errors
 from ..tokens import Doc
 from ..vocab import Vocab
-from ..util import registry
+from ..util import registry, from_disk, from_bytes

 from ..ml.models.coref_util import (
     create_gold_scores,
@@ -30,7 +31,6 @@
 default_config = """
 [model]
 @architectures = "spacy.Coref.v1"
-tok2vec_size = 768
 distance_embedding_size = 20
 hidden_size = 1024
 depth = 1
@@ -340,3 +340,57 @@ def initialize(

         assert len(X) > 0, Errors.E923.format(name=self.name)
         self.model.initialize(X=X, Y=Y)
+
+        # Store the input dimensionality. nI and nO are not stored explicitly
+        # for PyTorch models. This makes it tricky to reconstruct the model
+        # during deserialization. So, besides storing the labels, we also
+        # store the number of inputs.
+        coref_clusterer = self.model.get_ref("coref_clusterer")
+        self.cfg["nI"] = coref_clusterer.get_dim("nI")
+
+    def from_bytes(self, bytes_data, *, exclude=tuple()):
+        deserializers = {
+            "cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
+            "vocab": lambda b: self.vocab.from_bytes(b, exclude=exclude),
+        }
+        from_bytes(bytes_data, deserializers, exclude)
+
+        self._initialize_from_disk()
+
+        model_deserializers = {
+            "model": lambda b: self.model.from_bytes(b),
+        }
+        from_bytes(bytes_data, model_deserializers, exclude)
+
+        return self
+
+    def from_disk(self, path, exclude=tuple()):
+        def load_model(p):
+            try:
+                with open(p, "rb") as mfile:
+                    self.model.from_bytes(mfile.read())
+            except AttributeError:
+                raise ValueError(Errors.E149) from None
+
+        deserializers = {
+            "cfg": lambda p: self.cfg.update(srsly.read_json(p)),
+            "vocab": lambda p: self.vocab.from_disk(p, exclude=exclude),
+        }
+        from_disk(path, deserializers, exclude)
+
+        self._initialize_from_disk()
+
+        model_deserializers = {
+            "model": load_model,
+        }
+        from_disk(path, model_deserializers, exclude)
+
+        return self
+
+    def _initialize_from_disk(self):
+        # The PyTorch model is constructed lazily, so we need to
+        # explicitly initialize the model before deserialization.
+        model = self.model.get_ref("coref_clusterer")
+        if model.has_dim("nI") is None:
+            model.set_dim("nI", self.cfg["nI"])
+        self.model.initialize()
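
The ordering in the new from_bytes/from_disk matters: cfg, which now carries nI, is restored first; _initialize_from_disk then builds the lazily constructed PyTorch layer at the stored width; only afterwards are the weights loaded into it. A minimal sketch of that round trip, reusing the hypothetical build_lazy_model from the sketch after the coref.py diff:

import numpy

sample = numpy.zeros((4, 96), dtype="f")

# Save side: initialize from data, then record the inferred width
# alongside the serialized weights, as Coref.initialize now does via cfg.
saved = build_lazy_model()
saved.initialize(X=sample)
cfg = {"nI": saved.get_dim("nI")}
model_bytes = saved.to_bytes()

# Load side: no sample data is available, so the dim comes from cfg.
loaded = build_lazy_model()
if loaded.has_dim("nI") is None:
    loaded.set_dim("nI", cfg["nI"])
loaded.initialize()             # builds the inner layer at the right size
loaded.from_bytes(model_bytes)  # only now do the weights have a home
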
spacy/pipeline/span_predictor.py (56 additions, 2 deletions)

@@ -5,6 +5,7 @@
 from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
 from thinc.api import set_dropout_rate, to_categorical
 from itertools import islice
+import srsly

 from .trainable_pipe import TrainablePipe
 from ..language import Language
@@ -13,7 +14,7 @@
 from ..scorer import Scorer, doc2clusters
 from ..tokens import Doc
 from ..vocab import Vocab
-from ..util import registry
+from ..util import registry, from_bytes, from_disk

 from ..ml.models.coref_util import (
     MentionClusters,
@@ -23,7 +24,6 @@
 default_span_predictor_config = """
 [model]
 @architectures = "spacy.SpanPredictor.v1"
-tok2vec_size = 768
 hidden_size = 1024
 distance_embedding_size = 64
 conv_channels = 4
@@ -346,3 +346,57 @@ def initialize(

         assert len(X) > 0, Errors.E923.format(name=self.name)
         self.model.initialize(X=X, Y=Y)
+
+        # Store the input dimensionality. nI and nO are not stored explicitly
+        # for PyTorch models. This makes it tricky to reconstruct the model
+        # during deserialization. So, besides storing the labels, we also
+        # store the number of inputs.
+        span_predictor = self.model.get_ref("span_predictor")
+        self.cfg["nI"] = span_predictor.get_dim("nI")
+
+    def from_bytes(self, bytes_data, *, exclude=tuple()):
+        deserializers = {
+            "cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
+            "vocab": lambda b: self.vocab.from_bytes(b, exclude=exclude),
+        }
+        from_bytes(bytes_data, deserializers, exclude)
+
+        self._initialize_from_disk()
+
+        model_deserializers = {
+            "model": lambda b: self.model.from_bytes(b),
+        }
+        from_bytes(bytes_data, model_deserializers, exclude)
+
+        return self
+
+    def from_disk(self, path, exclude=tuple()):
+        def load_model(p):
+            try:
+                with open(p, "rb") as mfile:
+                    self.model.from_bytes(mfile.read())
+            except AttributeError:
+                raise ValueError(Errors.E149) from None
+
+        deserializers = {
+            "cfg": lambda p: self.cfg.update(srsly.read_json(p)),
+            "vocab": lambda p: self.vocab.from_disk(p, exclude=exclude),
+        }
+        from_disk(path, deserializers, exclude)
+
+        self._initialize_from_disk()
+
+        model_deserializers = {
+            "model": load_model,
+        }
+        from_disk(path, model_deserializers, exclude)
+
+        return self
+
+    def _initialize_from_disk(self):
+        # The PyTorch model is constructed lazily, so we need to
+        # explicitly initialize the model before deserialization.
+        model = self.model.get_ref("span_predictor")
+        if model.has_dim("nI") is None:
+            model.set_dim("nI", self.cfg["nI"])
+        self.model.initialize()
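
Both pipes lean on a detail of Thinc's dimension API: has_dim returns None for a dim that is declared but still unset, and False only when the model has no such dim at all, which is why the checks above read has_dim("nI") is None rather than not has_dim("nI"). A tiny sketch, assuming standard Thinc semantics:

from thinc.api import Model

noop = Model("demo", lambda model, X, is_train: (X, lambda dY: dY), dims={"nI": None})
assert noop.has_dim("nI") is None   # declared, value not yet known
noop.set_dim("nI", 64)
assert noop.has_dim("nI") is True   # set; get_dim("nI") is now safe
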