From a34d8c0a45c90974e0b602e272bf3c3cc17cb7dc Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Thu, 24 Aug 2023 11:09:20 +0200 Subject: [PATCH] [#2402] Filter events before assignment to flows https://openproject.platomo.de/work_packages/2402 --- .../application/analysis/traffic_counting.py | 48 ++++++++++++++++++- OTAnalytics/plugin_ui/main_application.py | 6 ++- .../analysis/test_traffic_counting.py | 21 +++++++- 3 files changed, 70 insertions(+), 5 deletions(-) diff --git a/OTAnalytics/application/analysis/traffic_counting.py b/OTAnalytics/application/analysis/traffic_counting.py index f2d3735ec..fa5509a9b 100644 --- a/OTAnalytics/application/analysis/traffic_counting.py +++ b/OTAnalytics/application/analysis/traffic_counting.py @@ -14,6 +14,7 @@ from OTAnalytics.domain.flow import Flow, FlowRepository from OTAnalytics.domain.section import SectionId from OTAnalytics.domain.track import TrackId, TrackRepository +from OTAnalytics.domain.types import EventType LEVEL_FLOW = "flow" LEVEL_CLASSIFICATION = "classification" @@ -479,7 +480,52 @@ def __repr__(self) -> str: return RoadUserAssignments.__name__ + repr(self._assignments) -class RoadUserAssigner: +class RoadUserAssigner(ABC): + """ + Class to assign tracks to flows. + """ + + @abstractmethod + def assign(self, events: Iterable[Event], flows: list[Flow]) -> RoadUserAssignments: + """ + Assign each track to exactly one flow. + + Args: + events (Iterable[Event]): events to be used during assignment + flows (list[Flow]): flows to assign tracks to + + Returns: + RoadUserAssignments: group of RoadUserAssignment objects + """ + raise NotImplementedError + + +class RoadUserAssignerDecorator(RoadUserAssigner): + """ + Decorator class for RoadUserAssigner. + + Args: + other: the RoadUserAssigner to be decorated. + """ + + def __init__(self, other: RoadUserAssigner) -> None: + self._other = other + + def assign(self, events: Iterable[Event], flows: list[Flow]) -> RoadUserAssignments: + return self._other.assign(events, flows) + + +class FilterBySectionEnterEvent(RoadUserAssignerDecorator): + """Decorator to filters events by event type section-enter.""" + + def assign(self, events: Iterable[Event], flows: list[Flow]) -> RoadUserAssignments: + section_enter_events: list[Event] = [ + event for event in events if event.event_type == EventType.SECTION_ENTER + ] + return super().assign(section_enter_events, flows) + + +class SimpleRoadUserAssigner(RoadUserAssigner): """ Class to assign tracks to flows. """ diff --git a/OTAnalytics/plugin_ui/main_application.py b/OTAnalytics/plugin_ui/main_application.py index a3d39bc38..6fff67b96 100644 --- a/OTAnalytics/plugin_ui/main_application.py +++ b/OTAnalytics/plugin_ui/main_application.py @@ -8,7 +8,9 @@ ) from OTAnalytics.application.analysis.traffic_counting import ( ExportTrafficCounting, + FilterBySectionEnterEvent, RoadUserAssigner, + SimpleRoadUserAssigner, SimpleTaggerFactory, ) from OTAnalytics.application.analysis.traffic_counting_specification import ExportCounts @@ -191,7 +193,7 @@ def start_gui(self) -> None: track_view_state = self._create_track_view_state() section_state = self._create_section_state() flow_state = self._create_flow_state() - road_user_assigner = RoadUserAssigner() + road_user_assigner = FilterBySectionEnterEvent(SimpleRoadUserAssigner()) pandas_data_provider = self._create_pandas_data_provider( datastore, track_view_state, pulling_progressbar_builder @@ -800,7 +802,7 @@ def _create_export_counts( return ExportTrafficCounting( event_repository, flow_repository, - RoadUserAssigner(), + FilterBySectionEnterEvent(SimpleRoadUserAssigner()), SimpleTaggerFactory(track_repository), FillZerosExporterFactory(SimpleExporterFactory()), ) diff --git a/tests/OTAnalytics/application/analysis/test_traffic_counting.py b/tests/OTAnalytics/application/analysis/test_traffic_counting.py index bfac36835..aa8400b39 100644 --- a/tests/OTAnalytics/application/analysis/test_traffic_counting.py +++ b/tests/OTAnalytics/application/analysis/test_traffic_counting.py @@ -19,11 +19,13 @@ ExporterFactory, ExportTrafficCounting, FillEmptyCount, + FilterBySectionEnterEvent, ModeTagger, MultiTag, RoadUserAssigner, RoadUserAssignment, RoadUserAssignments, + SimpleRoadUserAssigner, SingleTag, Tag, TaggedAssignments, @@ -518,7 +520,22 @@ def create_assignment_test_cases() -> ( return TestCaseBuilder().build_assignment_test_cases() -class TestRoadUserAssigner: +class TestFilterBySectionEnterEvent: + def test_assign_filters_by_section_enter_event(self) -> None: + enter_scene_event = Mock(spec=Event) + enter_scene_event.event_type = EventType.ENTER_SCENE + section_enter_event = Mock(spec=Event) + section_enter_event.event_type = EventType.SECTION_ENTER + flow = Mock(spec=Flow) + + road_user_assigner = Mock(spec=RoadUserAssigner) + + decorator = FilterBySectionEnterEvent(road_user_assigner) + decorator.assign([enter_scene_event, section_enter_event], [flow]) + road_user_assigner.assign.assert_called_once_with([section_enter_event], [flow]) + + +class TestSimpleRoadUserAssigner: @pytest.mark.parametrize( "events, flows, expected_result", create_assignment_test_cases() ) @@ -528,7 +545,7 @@ def test_run( flows: list[Flow], expected_result: RoadUserAssignments, ) -> None: - analysis = RoadUserAssigner() + analysis = SimpleRoadUserAssigner() result = analysis.assign(events, flows) assert result == expected_result