Skip to content

Commit

Permalink
chore: merge pull request #8 from swerik-project/add_context_functions
Browse files Browse the repository at this point in the history
add context sequence functions
  • Loading branch information
ninpnin authored Aug 9, 2024
2 parents c653801 + 9fb4fde commit 6b7c006
Showing 1 changed file with 74 additions and 0 deletions.
74 changes: 74 additions & 0 deletions pyriksdagen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import zipfile
import os
from trainerlog import get_logger
import re

LOGGER = get_logger("pyriksdagen")
XML_NS = "{http://www.w3.org/XML/1998/namespace}"
Expand Down Expand Up @@ -334,3 +335,76 @@ def get_data_location(partition):
d["metadata"] = os.environ.get("METADATA_PATH", "data")
return d[partition]

def remove_whitespace_from_sequence(text):
    """
    Remove repeated whitespace and replace all whitespace with spaces.

    Leading and trailing whitespace is dropped, and every run of whitespace
    characters (spaces, tabs, newlines, ...) is collapsed into one space.

    Args:
        text: String to normalise. ``None`` is accepted and treated as an
            empty string, so the ``.text`` of an element without text can
            be passed in directly.

    Returns:
        The normalised string. Empty if ``text`` is ``None``, empty, or
        consists only of whitespace.
    """
    if text is None:
        return ""
    # str.split() without arguments splits on runs of whitespace and never
    # yields empty strings, so no extra filtering is needed before joining.
    return ' '.join(text.split())

def get_sequence_from_elem_list(elem_list):
    """
    Get sequence from first elem in list.

    Args:
        elem_list: List of objects exposing a ``text`` attribute (e.g. the
            result of an xpath query). Only the first item is inspected.

    Returns:
        The text of the first element as a string. If the list is empty, or
        the first element has no text (``text`` is ``None``), returns an
        empty string.
    """
    if not elem_list:
        return ""
    text = elem_list[0].text
    # ``text`` may be None for an element without text content; str(None)
    # would leak the literal string "None" into the returned sequence.
    if text is None:
        return ""
    return str(text)

def extract_context_sequence(elem, context_type, target_length = 128, separator = '/n'):
    """
    Get sequence with context from xml element. Returns string.

    The result is built as
    ``<last sentence of previous elem> <separator> <text of elem>`` and, for
    ``'full_context'``, additionally
    ``<separator> <first sentence of next elem>``. "Previous" and "next" are
    the first ``note`` or ``seg`` elements matched on the xpath
    ``preceding::`` / ``following::`` axes. Missing neighbours contribute an
    empty string.

    Args:
        elem: XML element supporting ``.text`` and ``.xpath()``.
        context_type: Either ``'left_context'`` or ``'full_context'``.
        target_length: Budget used to limit the left context. The previous
            sentence is truncated to its last ``target_length // 2`` words
            (``'left_context'``) or ``target_length // 3`` words
            (``'full_context'``). Words are delimited by single spaces. The
            current and next sequences are not truncated.
        separator: Token placed, surrounded by single spaces, between the
            parts. NOTE(review): the default is the two characters ``/n``
            (slash, n), not a newline -- confirm this is intended.

    Returns:
        The context sequence as a string.

    Raises:
        ValueError: If ``context_type`` is not one of the supported values.
    """
    # Validate first: previously an unknown context_type surfaced as an
    # UnboundLocalError halfway through the function.
    if context_type not in ('left_context', 'full_context'):
        raise ValueError(
            f"Unknown context_type {context_type!r}; "
            "expected 'left_context' or 'full_context'"
        )

    def split_by_punctuation(sequence_string):
        # The capturing group keeps each '.', '!' or '?' as its own list
        # item; filter(None, ...) drops the empty strings re.split yields.
        return list(filter(None, re.split(r'([.!?])', sequence_string)))

    # elem.text may be None for an element without text; treat it as empty.
    current_sequence = remove_whitespace_from_sequence(elem.text or '')

    previous_elem_list = elem.xpath("preceding::*[local-name() = 'note' or local-name() = 'seg'][1]")
    previous_sequence = remove_whitespace_from_sequence(get_sequence_from_elem_list(previous_elem_list))
    # The last two items are normally "<sentence>" + "<punctuation>". When
    # the text does not end in punctuation they are "<punctuation>" +
    # "<trailing fragment>", hence stripping leading punctuation and spaces.
    previous_last_sentence = ''.join(split_by_punctuation(previous_sequence)[-2:]).lstrip('.!? ')

    if context_type == 'left_context':
        max_previous_length = target_length // 2
    else:
        max_previous_length = target_length // 3

    # Truncate the left context if too long. A slice of [-0:] would keep
    # every word, so a non-positive budget is handled explicitly.
    previous_words = previous_last_sentence.split(' ')
    if max_previous_length > 0:
        previous_words = previous_words[-max_previous_length:]
    else:
        previous_words = []
    previous_last_sentence = ' '.join(previous_words)

    left_context_sequence = previous_last_sentence + f' {separator} ' + current_sequence
    if context_type == 'left_context':
        return left_context_sequence

    next_elem_list = elem.xpath("following::*[local-name() = 'note' or local-name() = 'seg'][1]")
    next_sequence = remove_whitespace_from_sequence(get_sequence_from_elem_list(next_elem_list))
    # First two items: "<first sentence>" + "<its punctuation>".
    next_first_sentence = ''.join(split_by_punctuation(next_sequence)[:2])
    return left_context_sequence + f' {separator} ' + next_first_sentence

def get_context_sequences_for_protocol(protocol, context_type, target_length = 128, separator = '/n'):
    """
    Gets context sequences for a protocol.

    The protocol is parsed with blank text removed and walked with
    ``elem_iter``. Each ``note`` element contributes one entry, and each
    ``u`` element contributes one entry per child element. Elements with any
    other tag are ignored.

    Args:
        protocol: Source handed to ``etree.parse`` (e.g. a file path).
        context_type: Forwarded to ``extract_context_sequence``.
        target_length: Forwarded to ``extract_context_sequence``.
        separator: Forwarded to ``extract_context_sequence``.

    Returns:
        Dictionary with two parallel lists: ``'id'`` holds the ``xml:id``
        attribute of each element (``None`` when absent) and ``'text'`` the
        corresponding context sequence.
    """
    xml_parser = etree.XMLParser(remove_blank_text=True)
    tree_root = etree.parse(protocol, xml_parser).getroot()

    collected = {'id': [], 'text': []}
    for tag, elem in elem_iter(tree_root):
        # Decide which elements need a context sequence for this entry.
        if tag == 'note':
            targets = [elem]
        elif tag == 'u':
            targets = list(elem)
        else:
            continue

        for target in targets:
            collected['id'].append(target.get(f'{XML_NS}id'))
            collected['text'].append(
                extract_context_sequence(
                    target,
                    context_type=context_type,
                    target_length=target_length,
                    separator=separator,
                )
            )

    return collected

0 comments on commit 6b7c006

Please sign in to comment.