Skip to content

Commit

Permalink
improve typing in preprocess_aif
Browse files Browse the repository at this point in the history
  • Loading branch information
mirkolenz committed Oct 18, 2024
1 parent 7ef7eb2 commit 5df9c48
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 33 deletions.
2 changes: 1 addition & 1 deletion arguebuf/load/_load_aif.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import typing as t
from arguebuf.load.preprocess.preprocess_aif import process_hanging_nodes
from arguebuf.load._preprocess_aif import process_hanging_nodes
import pendulum
from arguebuf import dt
from arguebuf.model import Graph, utils
Expand Down
Original file line number Diff line number Diff line change
@@ -1,27 +1,30 @@
from dataclasses import dataclass, field
from typing import List

import pendulum
from arguebuf.schemas.aif import DATE_FORMAT
from arguebuf.model import Graph, utils

from arguebuf.model import utils
from arguebuf.schemas import aif


@dataclass
@dataclass(slots=True)
class NewNode:
node_id: int
node_id: str
text: str
node_type: str
timestamp: float
incoming_edges: List = field(default_factory=list)
outgoing_edges: List = field(default_factory=list)
timestamp: str
incoming_edges: list[aif.Edge] = field(default_factory=list)
outgoing_edges: list[aif.Edge] = field(default_factory=list)

def add_incoming_edge(self, edge):
def add_incoming_edge(self, edge: aif.Edge):
self.incoming_edges.append(edge)

def add_outgoing_edge(self, edge):
def add_outgoing_edge(self, edge: aif.Edge):
self.outgoing_edges.append(edge)


def get_connected_nodes_of_type(node: NewNode, node_type: str, allnodes: dict[NewNode]):
def get_connected_nodes_of_type(
node: NewNode, node_type: str, allnodes: dict[str, NewNode]
) -> list[NewNode]:
"""Returns a list of nodes connected to the given node (either incoming or outgoing) of a specific type."""
connected_node_ids = [edge["fromID"] for edge in node.incoming_edges] + [
edge["toID"] for edge in node.outgoing_edges
Expand All @@ -34,8 +37,8 @@ def get_connected_nodes_of_type(node: NewNode, node_type: str, allnodes: dict[Ne


def get_connected_nodes_excluding_types(
node: NewNode, excluded_types: list[str], allnodes: dict[NewNode]
):
node: NewNode, excluded_types: list[str], allnodes: dict[str, NewNode]
) -> list[NewNode]:
"""Returns a list of nodes connected to the given node (either incoming or outgoing) excluding specific types."""
connected_node_ids = [edge["fromID"] for edge in node.incoming_edges] + [
edge["toID"] for edge in node.outgoing_edges
Expand All @@ -47,7 +50,7 @@ def get_connected_nodes_excluding_types(
]


def are_nodes_connected(node1: NewNode, node2: NewNode):
def are_nodes_connected(node1: NewNode, node2: NewNode) -> bool:
"""Returns True if node1 and node2 are connected (either node1 to node2 or node2 to node1), otherwise False."""
return any(
edge["fromID"] == node1.node_id and edge["toID"] == node2.node_id
Expand All @@ -59,8 +62,8 @@ def are_nodes_connected(node1: NewNode, node2: NewNode):


def remove_loops_in_hanging_nodes(
hanging_nodes: dict[NewNode], obj: Graph, allnodes: dict[NewNode]
):
hanging_nodes: list[NewNode], obj: aif.Graph, allnodes: dict[str, NewNode]
) -> None:
"""
Removes loops in the graph that are formed specifically by connecting hanging nodes.
Steps:
Expand Down Expand Up @@ -107,11 +110,11 @@ def remove_loops_in_hanging_nodes(


def process_each_hanging_node(
hanging_nodes: dict[NewNode],
obj: Graph,
allnodes: dict[NewNode],
updated_hanging_nodes: dict[NewNode],
):
hanging_nodes: list[NewNode],
obj: aif.Graph,
allnodes: dict[str, NewNode],
updated_hanging_nodes: list[NewNode],
) -> aif.Graph:
"""
Finds and connects plausible connected Nodes based on specific criteria related to dialogues and transitions.
Expand All @@ -135,7 +138,7 @@ def process_each_hanging_node(
# 1)
# Iterate over hanging nodes
# all generated "Rephrase" Nodes should have the same datetime to filter them out for postprocessing
similar_datetime = pendulum.now().format(DATE_FORMAT)
similar_datetime = pendulum.now().format(aif.DATE_FORMAT)
for hanging_node in hanging_nodes:
# Check if the hanging node has exactly one incoming edge
if len(hanging_node.incoming_edges) == 1:
Expand Down Expand Up @@ -196,30 +199,30 @@ def process_each_hanging_node(
hanging_node, argument_node
):
new_node_id = utils.uuid()
new_node = {
new_node: aif.Node = {
"nodeID": new_node_id,
"text": "Default Rephrase",
"type": "MA",
"timestamp": similar_datetime,
"scheme": "Default Rephrase",
"schemeID": "144",
}
obj["nodes"].append(new_node)

# Create edges connecting hanging_node -> new_node and new_node -> ArgumentNode
new_edge_1_id = utils.uuid()
new_edge_1 = {
new_edge_1: aif.Edge = {
"edgeID": new_edge_1_id,
"fromID": hanging_node.node_id,
"toID": new_node_id,
"formEdgeID": None,
}
obj["edges"].append(new_edge_1)

new_edge_2_id = utils.uuid()
new_edge_2 = {
new_edge_2: aif.Edge = {
"edgeID": new_edge_2_id,
"fromID": new_node_id,
"toID": argument_node.node_id,
"formEdgeID": None,
}
obj["edges"].append(new_edge_2)
# Update the edges in the copied hanging nodes list
Expand All @@ -241,11 +244,11 @@ def process_each_hanging_node(
return obj


def create_and_process_nodes(obj: Graph):
def create_and_process_nodes(obj: aif.Graph) -> dict[str, NewNode]:
"""
creates and processes all Nodes, so they contain every ingoing/outgoing edge
"""
nodes = {}
nodes: dict[str, NewNode] = {}
for aif_node in obj["nodes"]:
node_id = aif_node["nodeID"]
node_text = aif_node.get("text", "")
Expand All @@ -258,7 +261,12 @@ def create_and_process_nodes(obj: Graph):
edge_id = aif_edge["edgeID"]
from_id = aif_edge["fromID"]
to_id = aif_edge["toID"]
edge_data = {"edgeID": edge_id, "fromID": from_id, "toID": to_id}
edge_data: aif.Edge = {
"edgeID": edge_id,
"fromID": from_id,
"toID": to_id,
"formEdgeID": None,
}

if from_id in nodes:
nodes[from_id].add_outgoing_edge(edge_data)
Expand All @@ -267,7 +275,7 @@ def create_and_process_nodes(obj: Graph):
return nodes


def find_hanging_nodes(allnodes: dict[NewNode]):
def find_hanging_nodes(allnodes: dict[str, NewNode]) -> list[NewNode]:
"""
Identifies nodes within a graph that have no outgoing edges and exactly one incoming edge, excluding nodes of type "L".
Returns:
Expand All @@ -283,7 +291,7 @@ def find_hanging_nodes(allnodes: dict[NewNode]):
return hanging_nodes


def process_hanging_nodes(obj: Graph):
def process_hanging_nodes(obj: aif.Graph) -> aif.Graph:
"""
Finds and connects specific Nodes which aren't connected to other ArgumentNodes yet
"""
Expand Down
Empty file.

0 comments on commit 5df9c48

Please sign in to comment.