diff --git a/allennlp_hub/__init__.py b/allennlp_hub/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/allennlp_hub/pretrained.py b/allennlp_hub/pretrained.py deleted file mode 100644 index 9716895..0000000 --- a/allennlp_hub/pretrained.py +++ /dev/null @@ -1,170 +0,0 @@ -import warnings - -from allennlp import predictors -from allennlp.predictors import Predictor -from allennlp.models.archival import load_archive - - -class PretrainedModel: - """ - A pretrained model is determined by both an archive file - (representing the trained model) - and a choice of predictor. - """ - - def __init__(self, archive_file: str, predictor_name: str) -> None: - self.archive_file = archive_file - self.predictor_name = predictor_name - - def predictor(self) -> Predictor: - archive = load_archive(self.archive_file) - return Predictor.from_archive(archive, self.predictor_name) - - -# TODO(Mark): Figure out a way to make PretrainedModel generic on Predictor, so we can remove these type ignores. - -# Models in the demo - - -def srl_with_elmo_luheng_2018() -> predictors.SemanticRoleLabelerPredictor: - with warnings.catch_warnings(): - warnings.simplefilter(action="ignore", category=DeprecationWarning) - model = PretrainedModel( - "https://allennlp.s3.amazonaws.com/models/srl-model-2018.05.25.tar.gz", - "semantic-role-labeling", - ) - return model.predictor() # type: ignore - - -def bert_srl_shi_2019() -> predictors.SemanticRoleLabelerPredictor: - with warnings.catch_warnings(): - warnings.simplefilter(action="ignore", category=DeprecationWarning) - model = PretrainedModel( - "https://s3-us-west-2.amazonaws.com/allennlp/models/bert-base-srl-2019.06.17.tar.gz", - "semantic-role-labeling", - ) - return model.predictor() # type: ignore - - -def bidirectional_attention_flow_seo_2017() -> predictors.BidafPredictor: - with warnings.catch_warnings(): - warnings.simplefilter(action="ignore", category=DeprecationWarning) - model = PretrainedModel( - "https://allennlp.s3.amazonaws.com/models/bidaf-model-2017.09.15-charpad.tar.gz", - "machine-comprehension", - ) - return model.predictor() # type: ignore - - -def naqanet_dua_2019() -> predictors.BidafPredictor: - with warnings.catch_warnings(): - warnings.simplefilter(action="ignore", category=DeprecationWarning) - model = PretrainedModel( - "https://allennlp.s3.amazonaws.com/models/naqanet-2019.04.29-fixed-weight-names.tar.gz", - "machine-comprehension", - ) - return model.predictor() # type: ignore - - -def open_information_extraction_stanovsky_2018() -> predictors.OpenIePredictor: - model = PretrainedModel( - "https://allennlp.s3.amazonaws.com/models/openie-model.2018-08-20.tar.gz", - "open-information-extraction", - ) - return model.predictor() # type: ignore - - -def decomposable_attention_with_elmo_parikh_2017() -> predictors.DecomposableAttentionPredictor: - with warnings.catch_warnings(): - warnings.simplefilter(action="ignore", category=DeprecationWarning) - model = PretrainedModel( - "https://allennlp.s3.amazonaws.com/models/decomposable-attention-elmo-2018.02.19.tar.gz", - "textual-entailment", - ) - return model.predictor() # type: ignore - - -def neural_coreference_resolution_lee_2017() -> predictors.CorefPredictor: - with warnings.catch_warnings(): - warnings.simplefilter(action="ignore", category=DeprecationWarning) - model = PretrainedModel( - "https://allennlp.s3.amazonaws.com/models/coref-model-2018.02.05.tar.gz", - "coreference-resolution", - ) - predictor = model.predictor() - - predictor._dataset_reader._token_indexers[ # type: ignore - "token_characters" - ]._min_padding_length = 5 - return predictor # type: ignore - - -def named_entity_recognition_with_elmo_peters_2018() -> predictors.SentenceTaggerPredictor: - with warnings.catch_warnings(): - warnings.simplefilter(action="ignore", category=DeprecationWarning) - model = PretrainedModel( - "https://allennlp.s3.amazonaws.com/models/ner-model-2018.12.18.tar.gz", - "sentence-tagger", - ) - predictor = model.predictor() - - predictor._dataset_reader._token_indexers[ # type: ignore - "token_characters" - ]._min_padding_length = 3 - return predictor # type: ignore - - -def fine_grained_named_entity_recognition_with_elmo_peters_2018() -> predictors.SentenceTaggerPredictor: - model = PretrainedModel( - "https://allennlp.s3.amazonaws.com/models/fine-grained-ner-model-elmo-2018.12.21.tar.gz", - "sentence-tagger", - ) - predictor = model.predictor() - - predictor._dataset_reader._token_indexers[ # type: ignore - "token_characters" - ]._min_padding_length = 3 - return predictor # type: ignore - - -def span_based_constituency_parsing_with_elmo_joshi_2018() -> predictors.ConstituencyParserPredictor: - with warnings.catch_warnings(): - warnings.simplefilter(action="ignore", category=DeprecationWarning) - model = PretrainedModel( - "https://allennlp.s3.amazonaws.com/models/elmo-constituency-parser-2018.03.14.tar.gz", - "constituency-parser", - ) - return model.predictor() # type: ignore - - -def biaffine_parser_stanford_dependencies_todzat_2017() -> predictors.BiaffineDependencyParserPredictor: - with warnings.catch_warnings(): - warnings.simplefilter(action="ignore", category=DeprecationWarning) - model = PretrainedModel( - "https://allennlp.s3.amazonaws.com/models/biaffine-dependency-parser-ptb-2018.08.23.tar.gz", - "biaffine-dependency-parser", - ) - return model.predictor() # type: ignore - - -# Models not in the demo - - -def biaffine_parser_universal_dependencies_todzat_2017() -> predictors.BiaffineDependencyParserPredictor: - with warnings.catch_warnings(): - warnings.simplefilter(action="ignore", category=DeprecationWarning) - model = PretrainedModel( - "https://allennlp.s3.amazonaws.com/models/biaffine-dependency-parser-ud-2018.08.23.tar.gz", - "biaffine-dependency-parser", - ) - return model.predictor() # type: ignore - - -def esim_nli_with_elmo_chen_2017() -> predictors.DecomposableAttentionPredictor: - with warnings.catch_warnings(): - warnings.simplefilter(action="ignore", category=DeprecationWarning) - model = PretrainedModel( - "https://allennlp.s3.amazonaws.com/models/esim-elmo-2018.05.17.tar.gz", - "textual-entailment", - ) - return model.predictor() # type: ignore diff --git a/allennlp_hub/pretrained/__init__.py b/allennlp_hub/pretrained/__init__.py new file mode 100644 index 0000000..2ea7fe2 --- /dev/null +++ b/allennlp_hub/pretrained/__init__.py @@ -0,0 +1,2 @@ +from allennlp_hub.pretrained.allennlp_pretrained import * +from allennlp_hub.pretrained.allennlp_semparse_pretrained import * diff --git a/allennlp_hub/pretrained/allennlp_pretrained.py b/allennlp_hub/pretrained/allennlp_pretrained.py new file mode 100644 index 0000000..fa0ec32 --- /dev/null +++ b/allennlp_hub/pretrained/allennlp_pretrained.py @@ -0,0 +1,122 @@ +from allennlp import predictors +from allennlp_hub.pretrained.helpers import _load_predictor +import allennlp.models + + +# Models in the main repo + + +def srl_with_elmo_luheng_2018() -> predictors.SemanticRoleLabelerPredictor: + predictor = _load_predictor( + "https://allennlp.s3.amazonaws.com/models/srl-model-2018.05.25.tar.gz", + "semantic-role-labeling", + ) + return predictor + + +def bert_srl_shi_2019() -> predictors.SemanticRoleLabelerPredictor: + predictor = _load_predictor( + "https://s3-us-west-2.amazonaws.com/allennlp/models/bert-base-srl-2019.06.17.tar.gz", + "semantic-role-labeling", + ) + return predictor + + +def bidirectional_attention_flow_seo_2017() -> predictors.BidafPredictor: + predictor = _load_predictor( + "https://allennlp.s3.amazonaws.com/models/bidaf-model-2017.09.15-charpad.tar.gz", + "machine-comprehension", + ) + return predictor + + +def naqanet_dua_2019() -> predictors.BidafPredictor: + predictor = _load_predictor( + "https://allennlp.s3.amazonaws.com/models/naqanet-2019.04.29-fixed-weight-names.tar.gz", + "machine-comprehension", + ) + return predictor + + +def open_information_extraction_stanovsky_2018() -> predictors.OpenIePredictor: + predictor = _load_predictor( + "https://allennlp.s3.amazonaws.com/models/openie-model.2018-08-20.tar.gz", + "open-information-extraction", + ) + return predictor + + +def decomposable_attention_with_elmo_parikh_2017() -> predictors.DecomposableAttentionPredictor: + predictor = _load_predictor( + "https://allennlp.s3.amazonaws.com/models/decomposable-attention-elmo-2018.02.19.tar.gz", + "textual-entailment", + ) + return predictor + + +def neural_coreference_resolution_lee_2017() -> predictors.CorefPredictor: + predictor = _load_predictor( + "https://allennlp.s3.amazonaws.com/models/coref-model-2018.02.05.tar.gz", + "coreference-resolution", + ) + + predictor._dataset_reader._token_indexers[ + "token_characters" + ]._min_padding_length = 5 + return predictor + + +def named_entity_recognition_with_elmo_peters_2018() -> predictors.SentenceTaggerPredictor: + predictor = _load_predictor( + "https://allennlp.s3.amazonaws.com/models/ner-model-2018.12.18.tar.gz", + "sentence-tagger", + ) + + predictor._dataset_reader._token_indexers[ + "token_characters" + ]._min_padding_length = 3 + return predictor + + +def fine_grained_named_entity_recognition_with_elmo_peters_2018() -> predictors.SentenceTaggerPredictor: + predictor = _load_predictor( + "https://allennlp.s3.amazonaws.com/models/fine-grained-ner-model-elmo-2018.12.21.tar.gz", + "sentence-tagger", + ) + + predictor._dataset_reader._token_indexers[ + "token_characters" + ]._min_padding_length = 3 + return predictor + + +def span_based_constituency_parsing_with_elmo_joshi_2018() -> predictors.ConstituencyParserPredictor: + predictor = _load_predictor( + "https://allennlp.s3.amazonaws.com/models/elmo-constituency-parser-2018.03.14.tar.gz", + "constituency-parser", + ) + return predictor + + +def biaffine_parser_stanford_dependencies_todzat_2017() -> predictors.BiaffineDependencyParserPredictor: + predictor = _load_predictor( + "https://allennlp.s3.amazonaws.com/models/biaffine-dependency-parser-ptb-2018.08.23.tar.gz", + "biaffine-dependency-parser", + ) + return predictor + + +def biaffine_parser_universal_dependencies_todzat_2017() -> predictors.BiaffineDependencyParserPredictor: + predictor = _load_predictor( + "https://allennlp.s3.amazonaws.com/models/biaffine-dependency-parser-ud-2018.08.23.tar.gz", + "biaffine-dependency-parser", + ) + return predictor + + +def esim_nli_with_elmo_chen_2017() -> predictors.DecomposableAttentionPredictor: + predictor = _load_predictor( + "https://allennlp.s3.amazonaws.com/models/esim-elmo-2018.05.17.tar.gz", + "textual-entailment", + ) + return predictor diff --git a/allennlp_hub/pretrained/allennlp_semparse_pretrained.py b/allennlp_hub/pretrained/allennlp_semparse_pretrained.py new file mode 100644 index 0000000..0d155c3 --- /dev/null +++ b/allennlp_hub/pretrained/allennlp_semparse_pretrained.py @@ -0,0 +1,38 @@ +from allennlp_hub.pretrained.helpers import _load_predictor +from allennlp_semparse import predictors as semparse_predictors +import allennlp_semparse.models + + +# AllenNLP Semparse models + + +def wikitables_parser_dasigi_2019() -> semparse_predictors.WikiTablesParserPredictor: + predictor = _load_predictor( + "https://storage.googleapis.com/allennlp-public-models/wikitables-model-2019.07.29.tar.gz", + "wikitables-parser", + ) + return predictor + + +def nlvr_parser_dasigi_2019() -> semparse_predictors.NlvrParserPredictor: + predictor = _load_predictor( + "https://storage.googleapis.com/allennlp-public-models/nlvr-erm-model-2018-12-18-rule-vocabulary-updated.tar.gz", + "nlvr-parser", + ) + return predictor + + +def atis_parser_lin_2019() -> semparse_predictors.AtisParserPredictor: + predictor = _load_predictor( + "https://storage.googleapis.com/allennlp-public-models/atis-parser-2018.11.10.tar.gz", + "atis-parser", + ) + return predictor + + +def quarel_parser_tafjord_2019() -> semparse_predictors.QuarelParserPredictor: + predictor = _load_predictor( + "https://storage.googleapis.com/allennlp-public-models/quarel-parser-zero-2018.12.20.tar.gz", + "quarel-parser", + ) + return predictor diff --git a/allennlp_hub/pretrained/helpers.py b/allennlp_hub/pretrained/helpers.py new file mode 100644 index 0000000..b83e580 --- /dev/null +++ b/allennlp_hub/pretrained/helpers.py @@ -0,0 +1,9 @@ +from allennlp.predictors import Predictor +from allennlp.models.archival import load_archive + +def _load_predictor(archive_file: str, predictor_name: str) -> Predictor: + """ + Helper to load the desired predictor from the given archive. + """ + archive = load_archive(archive_file) + return Predictor.from_archive(archive, predictor_name) diff --git a/setup.py b/setup.py index 97a6720..a16225a 100644 --- a/setup.py +++ b/setup.py @@ -34,10 +34,16 @@ # # As a mitigation, run `pip uninstall allennlp` before installing this # package. - # TODO(brendanr): Make these point to released versions. + # + # TODO(brendanr): Make these point to released versions. Currently + # allennlp-semparse is unreleased and it depends on a specific allennlp + # SHA. Due to the aforementioned setuptools bug, we explicitly set the + # allennlp version here to be that required by allennlp-semparse. + allennlp_sha = "93024e53c1445cb4630ee5c07926abff8943715f" + semparse_sha = "937d5945488a33c61d0047bd74d8106e60340bbd" install_requirements = [ - "allennlp @ git+ssh://git@github.com/allenai/allennlp@master#egg=allennlp", - "allennlp @ git+ssh://git@github.com/allenai/allennlp-semparse@master#egg=allennlp-semparse" + f"allennlp @ git+ssh://git@github.com/allenai/allennlp@{allennlp_sha}#egg=allennlp", + f"allennlp_semparse @ git+ssh://git@github.com/allenai/allennlp-semparse@{semparse_sha}#egg=allennlp-semparse", ] # make pytest-runner a conditional requirement, diff --git a/allennlp_hub/tests/sniff_test.py b/tests/pretrained/allennlp_pretrained_test.py similarity index 97% rename from allennlp_hub/tests/sniff_test.py rename to tests/pretrained/allennlp_pretrained_test.py index 7571703..71b9000 100644 --- a/allennlp_hub/tests/sniff_test.py +++ b/tests/pretrained/allennlp_pretrained_test.py @@ -5,9 +5,7 @@ from allennlp_hub import pretrained -class SniffTest(AllenNlpTestCase): - # TODO: Add semparse sniff tests. Were there ever any? - +class AllenNlpPretrainedTest(AllenNlpTestCase): def test_machine_comprehension(self): predictor = pretrained.bidirectional_attention_flow_seo_2017() @@ -16,9 +14,7 @@ def test_machine_comprehension(self): result = predictor.predict_json({"passage": passage, "question": question}) - correct = ( - "Keanu Reeves, Laurence Fishburne, Carrie-Anne Moss, Hugo Weaving, and Joe Pantoliano" - ) + correct = "Keanu Reeves, Laurence Fishburne, Carrie-Anne Moss, Hugo Weaving, and Joe Pantoliano" assert correct == result["best_span_str"] @@ -362,7 +358,9 @@ def test_ner(self): ] assert result["tags"] == ["B-PER", "L-PER", "O", "O", "O", "O", "U-LOC", "O"] - @pytest.mark.skipif(spacy.__version__ < "2.1", reason="this model changed from 2.0 to 2.1") + @pytest.mark.skipif( + spacy.__version__ < "2.1", reason="this model changed from 2.0 to 2.1" + ) def test_constituency_parsing(self): predictor = pretrained.span_based_constituency_parsing_with_elmo_joshi_2018() diff --git a/tests/pretrained/allennlp_semparse_pretrained_test.py b/tests/pretrained/allennlp_semparse_pretrained_test.py new file mode 100644 index 0000000..6e28c15 --- /dev/null +++ b/tests/pretrained/allennlp_semparse_pretrained_test.py @@ -0,0 +1,147 @@ +import pytest +import spacy + +from allennlp.common.testing import AllenNlpTestCase +from allennlp_hub import pretrained + + +class AllenNlpSemparsePretrainedTest(AllenNlpTestCase): + def test_wikitables_parser(self): + predictor = pretrained.wikitables_parser_dasigi_2019() + table = """# Event Year Season Flag bearer +7 2012 Summer Ele Opeloge +6 2008 Summer Ele Opeloge +5 2004 Summer Uati Maposua +4 2000 Summer Pauga Lalau +3 1996 Summer Bob Gasio +2 1988 Summer Henry Smith +1 1984 Summer Apelu Ioane""" + question = "How many years were held in summer?" + result = predictor.predict_json({"table": table, "question": question}) + assert result["answer"] == 7 + assert ( + result["logical_form"][0] + == "(count (filter_in all_rows string_column:season string:summer))" + ) + + def test_nlvr_parser(self): + predictor = pretrained.nlvr_parser_dasigi_2019() + structured_rep = """[ + [ + {"y_loc":13,"type":"square","color":"Yellow","x_loc":13,"size":20}, + {"y_loc":20,"type":"triangle","color":"Yellow","x_loc":44,"size":30}, + {"y_loc":90,"type":"circle","color":"#0099ff","x_loc":52,"size":10} + ], + [ + {"y_loc":57,"type":"square","color":"Black","x_loc":17,"size":20}, + {"y_loc":30,"type":"circle","color":"#0099ff","x_loc":76,"size":10}, + {"y_loc":12,"type":"square","color":"Black","x_loc":35,"size":10} + ], + [ + {"y_loc":40,"type":"triangle","color":"#0099ff","x_loc":26,"size":20}, + {"y_loc":70,"type":"triangle","color":"Black","x_loc":70,"size":30}, + {"y_loc":19,"type":"square","color":"Black","x_loc":35,"size":10} + ] + ]""" + sentence = "there is exactly one yellow object touching the edge" + result = predictor.predict_json( + {"structured_rep": structured_rep, "sentence": sentence} + ) + assert result["denotations"][0] == ["False"] + assert ( + result["logical_form"][0] + == "(object_count_equals (yellow (touch_wall all_objects)) 1)" + ) + + def test_atis_parser(self): + predictor = pretrained.atis_parser_lin_2019() + utterance = "give me flights on american airlines from milwaukee to phoenix" + result = predictor.predict_json({"utterance": utterance}) + predicted_sql_query = """ + (SELECT DISTINCT flight . flight_id + FROM flight + WHERE (flight . airline_code = 'AA' + AND (flight . from_airport IN + (SELECT airport_service . airport_code + FROM airport_service + WHERE airport_service . city_code IN + (SELECT city . city_code + FROM city + WHERE city . city_name = 'MILWAUKEE' ) ) + AND flight . to_airport IN + (SELECT airport_service . airport_code + FROM airport_service + WHERE airport_service . city_code IN + (SELECT city . city_code + FROM city + WHERE city . city_name = 'PHOENIX' ) ))) ) ;""" + assert result["predicted_sql_query"] == predicted_sql_query + + def test_quarel_parser(self): + predictor = pretrained.quarel_parser_tafjord_2019() + question = ( + "In his research, Joe is finding there is a lot more " + "diabetes in the city than out in the countryside. He " + "hypothesizes this is because people in _____ consume less " + "sugar. (A) city (B) countryside" + ) + qrspec = """[sugar, +diabetes] +[friction, -speed, -smoothness, -distance, +heat] +[speed, -time] +[speed, +distance] +[time, +distance] +[weight, -acceleration] +[strength, +distance] +[strength, +thickness] +[mass, +gravity] +[flexibility, -breakability] +[distance, -loudness, -brightness, -apparentSize] +[exerciseIntensity, +amountSweat]""" + entitycues = """friction: resistance, traction +speed: velocity, pace, fast, slow, faster, slower, slowly, quickly, rapidly +distance: length, way, far, near, further, longer, shorter, long, short, farther, furthest +heat: temperature, warmth, smoke, hot, hotter, cold, colder +smoothness: slickness, roughness, rough, smooth, rougher, smoother, bumpy, slicker +acceleration: +amountSweat: sweat, sweaty +apparentSize: size, large, small, larger, smaller +breakability: brittleness, brittle, break, solid +brightness: bright, shiny, faint +exerciseIntensity: excercise, run, walk +flexibility: flexible, stiff, rigid +gravity: +loudness: loud, faint, louder, fainter +mass: weight, heavy, light, heavier, lighter, massive +strength: power, strong, weak, stronger, weaker +thickness: thick, thin, thicker, thinner, skinny +time: long, short +weight: mass, heavy, light, heavier, lighter""" + result = predictor.predict_json( + {"question": question, "qrspec": qrspec, "entitycues": entitycues} + ) + assert result["answer"] == "B" + assert result["explanation"] == [ + { + "header": "Identified two worlds", + "content": ['world1 = "city"', 'world2 = "countryside"'], + }, + { + "header": "The question is stating", + "content": ['Diabetes is higher for "city"'], + }, + { + "header": "The answer options are stating", + "content": [ + 'A: Sugar is lower for "city"', + 'B: Sugar is lower for "countryside"', + ], + }, + { + "header": "Theory used", + "content": [ + 'When diabetes is higher then sugar is higher (for "city")', + 'Therefore sugar is lower for "countryside"', + "Therefore B is the correct answer", + ], + }, + ]