Skip to content

Commit

Permalink
Merge branch 'main' into fetch-meta-vars
Browse files Browse the repository at this point in the history
  • Loading branch information
BobBorges authored Oct 30, 2024
2 parents 09d2671 + f50f777 commit 998eeba
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pyriksdagen"
version = "0.17.0"
version = "1.2.1"
description = "Access the Riksdagen corpus"
authors = ["ninpnin <vainoyrjanainen@icloud.com>"]
repository = "https://github.com/welfare-state-analytics/riksdagen-corpus"
Expand Down
13 changes: 11 additions & 2 deletions pyriksdagen/refine.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def detect_mps(root, names_ids, pattern_db, mp_db=None, minister_db=None, minist

for tag, elem in elem_iter(root):
parent = elem.getparent()
if "type" not in parent.attrib or ("type" in parent.attrib and parent.attrib['type'] != "commentSection"): #ignore where people don't talk
if parent.attrib.get("type") != "commentSection": #ignore where people don't talk
if tag == "u":
# Deleting and adding attributes changes their order;
# Mark as 'delete' instead and delete later
Expand Down Expand Up @@ -195,7 +195,16 @@ def detect_mps(root, names_ids, pattern_db, mp_db=None, minister_db=None, minist

if current_speaker is None:
unknowns.append([protocol_id, elem.attrib.get(f'{xml_ns}id')] + [d.get(key, "") for key in unknown_variables])

else:
# If the whole section has no speeches, reset speaker and next/prev notation
if tag == "u":
elem.set("prev", "delete")
elem.set("next", "delete")
elem.set("who", "unknown")

current_speaker = None
prev = None

# Do two loops to preserve attribute order
for tag, elem in elem_iter(root):
if tag == "u":
Expand Down
76 changes: 76 additions & 0 deletions pyriksdagen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,79 @@ def get_gh_link(_file,
return gh


def remove_whitespace_from_sequence(text):
    """
    Normalise the whitespace of a string.

    Leading and trailing whitespace is dropped and every internal run of
    whitespace (spaces, tabs, newlines, ...) is collapsed to a single space.
    Input is string and output is string.
    """
    # str.split() with no argument splits on whitespace runs and never
    # yields empty tokens, so joining the tokens is all that is needed.
    return ' '.join(text.split())


def get_sequence_from_elem_list(elem_list):
    """
    Get sequence from first elem in list.

    Returns the `.text` of the first element as a string. If the list is
    empty, or the first element has no text (`.text` is None), returns an
    empty string.
    """
    if not elem_list:
        return ""
    text = elem_list[0].text
    # An element without text content reports None; str(None) would yield
    # the literal word "None", which must not leak into the sequence.
    if text is None:
        return ""
    return str(text)


def extract_context_sequence(elem, context_type, target_length = 128, separator = '/n'):
    """
    Get sequence with context from xml element. Returns string.

    The result is the whitespace-normalised text of `elem`, preceded by the
    last sentence of the nearest preceding `note` or `seg` element and, for
    'full_context', followed by the first sentence of the nearest following
    `note` or `seg` element. The parts are joined with ' {separator} '.

    Args:
        elem: XML element exposing `.text` and `.xpath()` (e.g. an lxml element).
        context_type: 'left_context' (previous sentence + current text) or
            'full_context' (previous sentence + current text + next sentence).
        target_length: budget for the previous sentence only; at most
            target_length // 2 ('left_context') or target_length // 3
            ('full_context') space-separated words of it are kept. The
            current text and the next sentence are never truncated.
        separator: string placed (surrounded by single spaces) between parts.
            NOTE(review): the default is the two characters '/' and 'n', not
            a newline escape -- confirm this is intended before changing it.

    Raises:
        ValueError: if `context_type` is not one of the two supported values.
    """
    if context_type == 'left_context':
        max_previous_length = target_length // 2
    elif context_type == 'full_context':
        max_previous_length = target_length // 3
    else:
        # Fail early with a clear message instead of an UnboundLocalError
        # further down.
        raise ValueError(
            f"Unsupported context_type {context_type!r}; "
            "expected 'left_context' or 'full_context'"
        )

    def split_on_punctuation(sequence_string):
        # The capturing group keeps each '.', '!' or '?' as its own item;
        # empty items produced by the split are dropped.
        return [piece for piece in re.split(r'([.!?])', sequence_string) if piece]

    # An element without text reports None; treat it as empty text.
    current_sequence = remove_whitespace_from_sequence(elem.text or '')

    previous_elem_list = elem.xpath("preceding::*[local-name() = 'note' or local-name() = 'seg'][1]")
    previous_sequence = remove_whitespace_from_sequence(get_sequence_from_elem_list(previous_elem_list))
    # The last two pieces are either "<sentence><punct>" or, when the text
    # does not end in punctuation, "<punct><trailing text>"; stripping leading
    # punctuation and spaces leaves just the final sentence in both cases.
    previous_last_sentence = ''.join(split_on_punctuation(previous_sequence)[-2:]).lstrip('.!? ')

    # Truncate the previous sentence to its last `max_previous_length` words.
    # A non-positive limit must keep nothing: a plain [-0:] slice would keep
    # the whole list.
    previous_words = previous_last_sentence.split(' ')
    kept_words = previous_words[-max_previous_length:] if max_previous_length > 0 else []
    previous_last_sentence = ' '.join(kept_words)

    left_context_sequence = previous_last_sentence + f' {separator} ' + current_sequence
    if context_type == 'left_context':
        return left_context_sequence

    # 'full_context': append the first sentence (text plus its punctuation
    # mark) of the nearest following note/seg.
    next_elem_list = elem.xpath("following::*[local-name() = 'note' or local-name() = 'seg'][1]")
    next_sequence = remove_whitespace_from_sequence(get_sequence_from_elem_list(next_elem_list))
    next_first_sentence = ''.join(split_on_punctuation(next_sequence)[:2])
    return left_context_sequence + f' {separator} ' + next_first_sentence


def get_context_sequences_for_protocol(protocol, context_type, target_length = 128, separator = '/n'):
    """
    Build the context sequence for every note and utterance child in a protocol.

    The protocol is parsed with blank text removed. For each `note` yielded by
    `elem_iter`, and for each child of every `u`, the element's xml id and its
    context sequence (see `extract_context_sequence`) are collected.

    Returns a dictionary with two parallel lists: 'id' (element ids) and
    'text' (the corresponding context sequences).
    """
    xml_parser = etree.XMLParser(remove_blank_text=True)
    tree_root = etree.parse(protocol, xml_parser).getroot()

    ids = []
    sequences = []
    for tag, elem in elem_iter(tree_root):
        if tag == 'note':
            # A note is itself the unit of text.
            targets = [elem]
        elif tag == 'u':
            # An utterance contributes one entry per child element.
            targets = elem
        else:
            continue

        for target in targets:
            ids.append(target.get(f'{XML_NS}id'))
            sequences.append(
                extract_context_sequence(
                    target,
                    context_type = context_type,
                    target_length = target_length,
                    separator = separator,
                )
            )

    return {'id' : ids, 'text' : sequences}

0 comments on commit 998eeba

Please sign in to comment.