diff --git a/OTAnalytics/plugin_datastore/track_store.py b/OTAnalytics/plugin_datastore/track_store.py index fed307871..c96d38bc9 100644 --- a/OTAnalytics/plugin_datastore/track_store.py +++ b/OTAnalytics/plugin_datastore/track_store.py @@ -78,7 +78,7 @@ def frame(self) -> int: @property def occurrence(self) -> datetime: - return self._occurrence[1] + return self._occurrence @property def interpolated_detection(self) -> bool: @@ -282,7 +282,7 @@ def as_generator(self) -> Generator[Track, None, None]: yield from [] else: for current in track_ids.array: - yield self.__create_track_flyweight(current) + yield self._create_track_flyweight(current) @staticmethod def from_list( @@ -354,7 +354,7 @@ def get_for(self, id: TrackId) -> Optional[Track]: if self._dataset.empty: return None try: - return self.__create_track_flyweight(id.id) + return self._create_track_flyweight(id.id) except KeyError: return None @@ -393,11 +393,15 @@ def as_list(self) -> list[Track]: "Creating track flyweight objects which is really slow in " f"'{PandasTrackDataset.as_list.__name__}'." ) - return [self.__create_track_flyweight(current) for current in track_ids] + return [self._create_track_flyweight(current) for current in track_ids] - def __create_track_flyweight(self, track_id: str) -> Track: - track_frame = self._dataset.loc[[track_id], :] - return PandasTrack(track_id, track_frame) + def _create_track_flyweight(self, track_id: str) -> Track: + track_frame = self._dataset.loc[track_id, :] + if isinstance(track_frame, DataFrame): + return PandasTrack(track_id, track_frame) + if isinstance(track_frame, Series): + return PandasTrack(track_id, track_frame.to_frame(track_id)) + raise NotImplementedError(f"Not implemented for {type(track_frame)}") def get_data(self) -> DataFrame: return self._dataset @@ -444,10 +448,9 @@ def __len__(self) -> int: return len(self._dataset.index.get_level_values(LEVEL_TRACK_ID).unique()) def filter_by_min_detection_length(self, length: int) -> "PandasTrackDataset": - # groupby.size always returns a series - detection_counts_per_track: Series[int] = self._dataset.groupby( # type: ignore + detection_counts_per_track: Series[int] = self._dataset.groupby( level=LEVEL_TRACK_ID - ).size() + )[track.FRAME].size() filtered_ids = detection_counts_per_track[ detection_counts_per_track >= length ].index diff --git a/tests/OTAnalytics/plugin_datastore/test_track_store.py b/tests/OTAnalytics/plugin_datastore/test_track_store.py index 3bfadb851..b56d8dc9d 100644 --- a/tests/OTAnalytics/plugin_datastore/test_track_store.py +++ b/tests/OTAnalytics/plugin_datastore/test_track_store.py @@ -60,7 +60,7 @@ def test_properties(self) -> None: python_detection = builder.build_detections()[0] data = Series( python_detection.to_dict(), - name=(python_detection.track_id.id, python_detection.occurrence), + name=python_detection.occurrence, ) pandas_detection = PandasDetection(python_detection.track_id.id, data) @@ -77,12 +77,9 @@ def test_properties(self) -> None: builder.append_detection() python_track = builder.build_track() detections = [detection.to_dict() for detection in python_track.detections] - data = ( - DataFrame(detections) - .set_index([track.TRACK_ID, track.OCCURRENCE]) - .sort_index() - ) + data = DataFrame(detections).set_index([track.OCCURRENCE]).sort_index() data[track.TRACK_CLASSIFICATION] = data[track.CLASSIFICATION] + data = data.drop([track.TRACK_ID], axis=1) pandas_track = PandasTrack(python_track.id.id, data) assert_equal_track_properties(pandas_track, python_track) @@ -628,3 +625,15 @@ def test_get_max_confidences_for( result = filled_dataset.get_max_confidences_for([car_id, pedestrian_id]) assert result == {car_id: 0.8, pedestrian_id: 0.9} + + def test_create_test_flyweight_with_single_detection( + self, track_geometry_factory: TRACK_GEOMETRY_FACTORY + ) -> None: + track_builder = TrackBuilder() + track_builder.append_detection() + single_detection_track = track_builder.build_track() + dataset = PandasTrackDataset.from_list( + [single_detection_track], track_geometry_factory + ) + result = dataset._create_track_flyweight(single_detection_track.id.id) + assert_equal_track_properties(result, single_detection_track)