From 5423b0b07fe8e3dda00bcb5514d3857b90c3acaf Mon Sep 17 00:00:00 2001 From: James Dunham Date: Sun, 3 Sep 2017 16:47:42 -0400 Subject: [PATCH] add corresponding changes to neuroner.py --- src/neuroner.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/neuroner.py b/src/neuroner.py index cb25314a..f5c96cdb 100644 --- a/src/neuroner.py +++ b/src/neuroner.py @@ -79,6 +79,7 @@ def _load_parameters(self, parameters_filepath, arguments={}, verbose=True): 'reload_token_lstm':True, 'remap_unknown_tokens_to_unk':True, 'spacylanguage':'en', + 'split_discontinuous':False, 'tagging_format':'bioes', 'token_embedding_dimension':100, 'token_lstm_hidden_state_dimension':100, @@ -114,7 +115,8 @@ def _load_parameters(self, parameters_filepath, arguments={}, verbose=True): parameters[k] = float(v) elif k in ['remap_unknown_tokens_to_unk', 'use_character_lstm', 'use_crf', 'train_model', 'use_pretrained_model', 'debug', 'verbose', 'reload_character_embeddings', 'reload_character_lstm', 'reload_token_embeddings', 'reload_token_lstm', 'reload_feedforward', 'reload_crf', - 'check_for_lowercase', 'check_for_digits_replaced_with_zeros', 'freeze_token_embeddings', 'load_only_pretrained_token_embeddings', 'load_all_pretrained_token_embeddings']: + 'check_for_lowercase', 'check_for_digits_replaced_with_zeros', 'freeze_token_embeddings', 'load_only_pretrained_token_embeddings', 'load_all_pretrained_token_embeddings', + 'split_discontinuous']: parameters[k] = distutils.util.strtobool(v) # If loading pretrained model, set the model hyperparameters according to the pretraining parameters if parameters['use_pretrained_model']: @@ -147,7 +149,7 @@ def _get_valid_dataset_filepaths(self, parameters, dataset_types=['train', 'vali if os.path.exists(dataset_brat_folders[dataset_type]) and len(glob.glob(os.path.join(dataset_brat_folders[dataset_type], '*.txt'))) > 0: # Check compatibility between conll and brat files - 
brat_to_conll.check_brat_annotation_and_text_compatibility(dataset_brat_folders[dataset_type]) + brat_to_conll.check_brat_annotation_and_text_compatibility(dataset_brat_folders[dataset_type], parameters['split_discontinuous']) if os.path.exists(dataset_compatible_with_brat_filepath): dataset_filepaths[dataset_type] = dataset_compatible_with_brat_filepath conll_to_brat.check_compatibility_between_conll_and_brat_text(dataset_filepaths[dataset_type], dataset_brat_folders[dataset_type]) @@ -168,7 +170,8 @@ def _get_valid_dataset_filepaths(self, parameters, dataset_types=['train', 'vali conll_to_brat.check_compatibility_between_conll_and_brat_text(dataset_filepath_for_tokenizer, dataset_brat_folders[dataset_type]) else: # Populate conll file based on brat files - brat_to_conll.brat_to_conll(dataset_brat_folders[dataset_type], dataset_filepath_for_tokenizer, parameters['tokenizer'], parameters['spacylanguage']) + brat_to_conll.brat_to_conll(dataset_brat_folders[dataset_type], + dataset_filepath_for_tokenizer, parameters['tokenizer'], parameters['spacylanguage'], parameters['split_discontinuous']) dataset_filepaths[dataset_type] = dataset_filepath_for_tokenizer # Brat text files do not exist @@ -238,6 +241,7 @@ def __init__(self, reload_token_lstm=argument_default_value, remap_unknown_tokens_to_unk=argument_default_value, spacylanguage=argument_default_value, + split_discontinuous=argument_default_value, tagging_format=argument_default_value, token_embedding_dimension=argument_default_value, token_lstm_hidden_state_dimension=argument_default_value, @@ -475,7 +479,7 @@ def predict(self, text): # Print and output result text_filepath = os.path.join(self.stats_graph_folder, 'brat', 'deploy', os.path.basename(dataset_brat_deploy_filepath)) annotation_filepath = os.path.join(self.stats_graph_folder, 'brat', 'deploy', '{0}.ann'.format(utils.get_basename_without_extension(dataset_brat_deploy_filepath))) - text2, entities = brat_to_conll.get_entities_from_brat(text_filepath, 
annotation_filepath, verbose=True) + text2, entities = brat_to_conll.get_entities_from_brat(text_filepath, annotation_filepath, self.parameters['split_discontinuous'], verbose=True) assert(text == text2) return entities