Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Move default predictors #31

Merged
merged 5 commits into from
Apr 28, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions allennlp_models/coref/coref_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,3 +871,5 @@ def _compute_coreference_scores(
# Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
coreference_scores = torch.cat([dummy_scores, antecedent_scores], -1)
return coreference_scores

default_predictor = "coreference-resolution"
2 changes: 2 additions & 0 deletions allennlp_models/ner/crf_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,3 +296,5 @@ def get_metrics(self, reset: bool = False) -> Dict[str, float]:
else:
metrics_to_return.update({x: y for x, y in f1_dict.items() if "overall" in x})
return metrics_to_return

default_predictor = "sentence-tagger"
1 change: 1 addition & 0 deletions allennlp_models/nli/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from allennlp_models.nli.snli_reader import SnliReader
from allennlp_models.nli.decomposable_attention_model import DecomposableAttention
from allennlp_models.nli.decomposable_attention_predictor import DecomposableAttentionPredictor
from allennlp_models.nli.bimpm_model import BiMpm
from allennlp_models.nli.esim_model import ESIM
from allennlp_models.nli.quora_paraphrase_reader import QuoraParaphraseDatasetReader
2 changes: 2 additions & 0 deletions allennlp_models/nli/bimpm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,5 @@ def get_metrics(self, reset: bool = False) -> Dict[str, float]:
return {
metric_name: metric.get_metric(reset) for metric_name, metric in self.metrics.items()
}

default_predictor = "textual-entailment"
2 changes: 2 additions & 0 deletions allennlp_models/nli/decomposable_attention_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,5 @@ def forward( # type: ignore

def get_metrics(self, reset: bool = False) -> Dict[str, float]:
return {"accuracy": self._accuracy.get_metric(reset)}

default_predictor = "textual-entailment"
58 changes: 58 additions & 0 deletions allennlp_models/nli/decomposable_attention_predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import List, Dict
from copy import deepcopy

import numpy
from overrides import overrides

from allennlp.common.util import JsonDict
from allennlp.data import Instance
from allennlp.predictors.predictor import Predictor
from allennlp.data.fields import LabelField


@Predictor.register("textual-entailment")
class DecomposableAttentionPredictor(Predictor):
"""
Predictor for the [`DecomposableAttention`](../models/decomposable_attention.md) model.

Registered as a `Predictor` with name "textual-entailment".
"""

def predict(self, premise: str, hypothesis: str) -> JsonDict:
"""
Predicts whether the hypothesis is entailed by the premise text.

# Parameters

premise : `str`
A passage representing what is assumed to be true.

hypothesis : `str`
A sentence that may be entailed by the premise.

# Returns

`JsonDict`
A dictionary where the key "label_probs" determines the probabilities of each of
[entailment, contradiction, neutral].
"""
return self.predict_json({"premise": premise, "hypothesis": hypothesis})

@overrides
def _json_to_instance(self, json_dict: JsonDict) -> Instance:
"""
Expects JSON that looks like `{"premise": "...", "hypothesis": "..."}`.
"""
premise_text = json_dict["premise"]
hypothesis_text = json_dict["hypothesis"]
return self._dataset_reader.text_to_instance(premise_text, hypothesis_text)

@overrides
def predictions_to_labeled_instances(
self, instance: Instance, outputs: Dict[str, numpy.ndarray]
) -> List[Instance]:
new_instance = deepcopy(instance)
label = numpy.argmax(outputs["label_logits"])
# Skip indexing, we have integer representations of the strings "entailment", etc.
new_instance.add_field("label", LabelField(int(label), skip_indexing=True))
return [new_instance]
2 changes: 2 additions & 0 deletions allennlp_models/rc/bidaf/bidaf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,3 +366,5 @@ def get_best_span(
span_start_indices = best_spans // passage_length
span_end_indices = best_spans % passage_length
return torch.stack([span_start_indices, span_end_indices], dim=-1)

default_predictor = "reading-comprehension"
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@

POS_TO_IGNORE = {"`", "''", ":", ",", ".", "PU", "PUNCT", "SYM"}

# exist_ok has to be true until we remove this from the core library
@Model.register("biaffine_parser", exist_ok=True)

@Model.register("biaffine_parser")
class BiaffineDependencyParser(Model):
"""
This dependency parser follows the model of
Expand Down Expand Up @@ -681,3 +681,5 @@ def _get_mask_for_eval(
@overrides
def get_metrics(self, reset: bool = False) -> Dict[str, float]:
return self._attachment_scores.get_metric(reset)

default_predictor = "biaffine-dependency-parser"
Original file line number Diff line number Diff line change
Expand Up @@ -494,3 +494,5 @@ def get_metrics(self, reset: bool = False) -> Dict[str, float]:
evalb_metrics = self._evalb_score.get_metric(reset=reset)
all_metrics.update(evalb_metrics)
return all_metrics

default_predictor = "constituency-parser"
2 changes: 2 additions & 0 deletions allennlp_models/syntax/srl/srl_bert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,5 @@ def get_start_transitions(self):
start_transitions[i] = float("-inf")

return start_transitions

default_predictor = "semantic-role-labeling"
2 changes: 2 additions & 0 deletions allennlp_models/syntax/srl/srl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,8 @@ def get_start_transitions(self):

return start_transitions

default_predictor = "semantic-role-labeling"


def write_to_conll_eval_file(
prediction_file: TextIO,
Expand Down