Skip to content

Commit

Permalink
Add ConsistentXLMModel (facebookresearch#913)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#913

Add an XLM model that accepts two text columns (and no label columns) as input. The model will evaluate (soft) predictions on the reference text input and treat that as the target distribution for the text in the "tokens" text input. This can be used for example when the two text columns are translations of each other (possibly multi-tasked with a regular labeled task).

Reviewed By: rutyrinott

Differential Revision: D16786687

fbshipit-source-id: a62d3ec4c27cc7e891375d38459ba1621e5b9d9d
  • Loading branch information
Michael Wu authored and facebook-github-bot committed Aug 17, 2019
1 parent e32c2a5 commit 90a00c3
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 4 deletions.
3 changes: 2 additions & 1 deletion pytext/metric_reporters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .compositional_metric_reporter import CompositionalMetricReporter
from .intent_slot_detection_metric_reporter import IntentSlotMetricReporter
from .language_model_metric_reporter import LanguageModelMetricReporter
from .metric_reporter import MetricReporter
from .metric_reporter import MetricReporter, PureLossMetricReporter
from .pairwise_ranking_metric_reporter import PairwiseRankingMetricReporter
from .regression_metric_reporter import RegressionMetricReporter
from .squad_metric_reporter import SquadMetricReporter
Expand All @@ -32,4 +32,5 @@
"CompositionalMetricReporter",
"PairwiseRankingMetricReporter",
"SequenceTaggingMetricReporter",
"PureLossMetricReporter",
]
13 changes: 13 additions & 0 deletions pytext/metric_reporters/metric_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from pytext.utils import cuda
from pytext.utils.meter import TimeMeter

from .channel import ConsoleChannel


class MetricReporter(Component):
"""
Expand Down Expand Up @@ -268,3 +270,14 @@ def compare_metric(self, new_metric, old_metric):
if new == old:
return False
return (new < old) == self.lower_is_better


class PureLossMetricReporter(MetricReporter):
lower_is_better = True

@classmethod
def from_config(cls, config, *args, **kwargs):
return cls([ConsoleChannel()], config.pep_format)

def calculate_metric(self):
return self.calculate_loss()
7 changes: 4 additions & 3 deletions pytext/task/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
LanguageModelMetricReporter,
MultiLabelClassificationMetricReporter,
PairwiseRankingMetricReporter,
PureLossMetricReporter,
RegressionMetricReporter,
SequenceTaggingMetricReporter,
SquadMetricReporter,
Expand Down Expand Up @@ -153,9 +154,9 @@ def format_prediction(cls, predictions, scores, context, target_meta):
class DocumentClassificationTask(NewTask):
class Config(NewTask.Config):
model: BaseModel.Config = DocModel.Config()
metric_reporter: ClassificationMetricReporter.Config = (
ClassificationMetricReporter.Config()
)
metric_reporter: Union[
ClassificationMetricReporter.Config, PureLossMetricReporter.Config
] = (ClassificationMetricReporter.Config())
# for multi-label classification task,
# choose MultiLabelClassificationMetricReporter

Expand Down

0 comments on commit 90a00c3

Please sign in to comment.