Skip to content

Commit

Permalink
add corresponding changes to neuroner.py
Browse files (browse the repository at this point in the history)
  • Loading branch information
jamesdunham committed Sep 3, 2017
1 parent 54d2c3b commit 5423b0b
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/neuroner.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -79,6 +79,7 @@ def _load_parameters(self, parameters_filepath, arguments={}, verbose=True):
'reload_token_lstm':True,
'remap_unknown_tokens_to_unk':True,
'spacylanguage':'en',
'split_discontinuous':False,
'tagging_format':'bioes',
'token_embedding_dimension':100,
'token_lstm_hidden_state_dimension':100,
Expand Down Expand Up @@ -114,7 +115,8 @@ def _load_parameters(self, parameters_filepath, arguments={}, verbose=True):
parameters[k] = float(v)
elif k in ['remap_unknown_tokens_to_unk', 'use_character_lstm', 'use_crf', 'train_model', 'use_pretrained_model', 'debug', 'verbose',
'reload_character_embeddings', 'reload_character_lstm', 'reload_token_embeddings', 'reload_token_lstm', 'reload_feedforward', 'reload_crf',
'check_for_lowercase', 'check_for_digits_replaced_with_zeros', 'freeze_token_embeddings', 'load_only_pretrained_token_embeddings', 'load_all_pretrained_token_embeddings']:
'check_for_lowercase', 'check_for_digits_replaced_with_zeros', 'freeze_token_embeddings', 'load_only_pretrained_token_embeddings', 'load_all_pretrained_token_embeddings',
'split_discontinuous']:
parameters[k] = distutils.util.strtobool(v)
# If loading pretrained model, set the model hyperparameters according to the pretraining parameters
if parameters['use_pretrained_model']:
Expand Down Expand Up @@ -147,7 +149,7 @@ def _get_valid_dataset_filepaths(self, parameters, dataset_types=['train', 'vali
if os.path.exists(dataset_brat_folders[dataset_type]) and len(glob.glob(os.path.join(dataset_brat_folders[dataset_type], '*.txt'))) > 0:

# Check compatibility between conll and brat files
brat_to_conll.check_brat_annotation_and_text_compatibility(dataset_brat_folders[dataset_type])
brat_to_conll.check_brat_annotation_and_text_compatibility(dataset_brat_folders[dataset_type], parameters['split_discontinuous'])
if os.path.exists(dataset_compatible_with_brat_filepath):
dataset_filepaths[dataset_type] = dataset_compatible_with_brat_filepath
conll_to_brat.check_compatibility_between_conll_and_brat_text(dataset_filepaths[dataset_type], dataset_brat_folders[dataset_type])
Expand All @@ -168,7 +170,8 @@ def _get_valid_dataset_filepaths(self, parameters, dataset_types=['train', 'vali
conll_to_brat.check_compatibility_between_conll_and_brat_text(dataset_filepath_for_tokenizer, dataset_brat_folders[dataset_type])
else:
# Populate conll file based on brat files
brat_to_conll.brat_to_conll(dataset_brat_folders[dataset_type], dataset_filepath_for_tokenizer, parameters['tokenizer'], parameters['spacylanguage'])
brat_to_conll.brat_to_conll(dataset_brat_folders[dataset_type],
dataset_filepath_for_tokenizer, parameters['tokenizer'], parameters['spacylanguage'], parameters['split_discontinuous'])
dataset_filepaths[dataset_type] = dataset_filepath_for_tokenizer

# Brat text files do not exist
Expand Down Expand Up @@ -238,6 +241,7 @@ def __init__(self,
reload_token_lstm=argument_default_value,
remap_unknown_tokens_to_unk=argument_default_value,
spacylanguage=argument_default_value,
split_discontinuous=argument_default_value,
tagging_format=argument_default_value,
token_embedding_dimension=argument_default_value,
token_lstm_hidden_state_dimension=argument_default_value,
Expand Down Expand Up @@ -475,7 +479,7 @@ def predict(self, text):
# Print and output result
text_filepath = os.path.join(self.stats_graph_folder, 'brat', 'deploy', os.path.basename(dataset_brat_deploy_filepath))
annotation_filepath = os.path.join(self.stats_graph_folder, 'brat', 'deploy', '{0}.ann'.format(utils.get_basename_without_extension(dataset_brat_deploy_filepath)))
text2, entities = brat_to_conll.get_entities_from_brat(text_filepath, annotation_filepath, verbose=True)
text2, entities = brat_to_conll.get_entities_from_brat(text_filepath, annotation_filepath, split_discontinuous, verbose=True)
assert(text == text2)
return entities

Expand Down

0 comments on commit 5423b0b

Please sign in to comment.