Skip to content

Commit

Permalink
Fix performance drop when creating track flyweights
Browse files Browse the repository at this point in the history
  • Loading branch information
randy-seng committed Aug 20, 2024
1 parent 233e2bb commit 4e1e241
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 16 deletions.
23 changes: 13 additions & 10 deletions OTAnalytics/plugin_datastore/track_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def frame(self) -> int:

@property
def occurrence(self) -> datetime:
return self._occurrence[1]
return self._occurrence

@property
def interpolated_detection(self) -> bool:
Expand Down Expand Up @@ -282,7 +282,7 @@ def as_generator(self) -> Generator[Track, None, None]:
yield from []
else:
for current in track_ids.array:
yield self.__create_track_flyweight(current)
yield self._create_track_flyweight(current)

@staticmethod
def from_list(
Expand Down Expand Up @@ -354,7 +354,7 @@ def get_for(self, id: TrackId) -> Optional[Track]:
if self._dataset.empty:
return None
try:
return self.__create_track_flyweight(id.id)
return self._create_track_flyweight(id.id)
except KeyError:
return None

Expand Down Expand Up @@ -393,11 +393,15 @@ def as_list(self) -> list[Track]:
"Creating track flyweight objects which is really slow in "
f"'{PandasTrackDataset.as_list.__name__}'."
)
return [self.__create_track_flyweight(current) for current in track_ids]
return [self._create_track_flyweight(current) for current in track_ids]

def __create_track_flyweight(self, track_id: str) -> Track:
track_frame = self._dataset.loc[[track_id], :]
return PandasTrack(track_id, track_frame)
def _create_track_flyweight(self, track_id: str) -> Track:
track_frame = self._dataset.loc[track_id, :]
if isinstance(track_frame, DataFrame):
return PandasTrack(track_id, track_frame)
if isinstance(track_frame, Series):
return PandasTrack(track_id, track_frame.to_frame(track_id))
raise NotImplementedError(f"Not implemented for {type(track_frame)}")

def get_data(self) -> DataFrame:
return self._dataset
Expand Down Expand Up @@ -444,10 +448,9 @@ def __len__(self) -> int:
return len(self._dataset.index.get_level_values(LEVEL_TRACK_ID).unique())

def filter_by_min_detection_length(self, length: int) -> "PandasTrackDataset":
# groupby.size always returns a series
detection_counts_per_track: Series[int] = self._dataset.groupby( # type: ignore
detection_counts_per_track: Series[int] = self._dataset.groupby(
level=LEVEL_TRACK_ID
).size()
)[track.FRAME].size()
filtered_ids = detection_counts_per_track[
detection_counts_per_track >= length
].index
Expand Down
21 changes: 15 additions & 6 deletions tests/OTAnalytics/plugin_datastore/test_track_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_properties(self) -> None:
python_detection = builder.build_detections()[0]
data = Series(
python_detection.to_dict(),
name=(python_detection.track_id.id, python_detection.occurrence),
name=python_detection.occurrence,
)
pandas_detection = PandasDetection(python_detection.track_id.id, data)

Expand All @@ -77,12 +77,9 @@ def test_properties(self) -> None:
builder.append_detection()
python_track = builder.build_track()
detections = [detection.to_dict() for detection in python_track.detections]
data = (
DataFrame(detections)
.set_index([track.TRACK_ID, track.OCCURRENCE])
.sort_index()
)
data = DataFrame(detections).set_index([track.OCCURRENCE]).sort_index()
data[track.TRACK_CLASSIFICATION] = data[track.CLASSIFICATION]
data = data.drop([track.TRACK_ID], axis=1)
pandas_track = PandasTrack(python_track.id.id, data)

assert_equal_track_properties(pandas_track, python_track)
Expand Down Expand Up @@ -628,3 +625,15 @@ def test_get_max_confidences_for(

result = filled_dataset.get_max_confidences_for([car_id, pedestrian_id])
assert result == {car_id: 0.8, pedestrian_id: 0.9}

def test_create_test_flyweight_with_single_detection(
self, track_geometry_factory: TRACK_GEOMETRY_FACTORY
) -> None:
track_builder = TrackBuilder()
track_builder.append_detection()
single_detection_track = track_builder.build_track()
dataset = PandasTrackDataset.from_list(
[single_detection_track], track_geometry_factory
)
result = dataset._create_track_flyweight(single_detection_track.id.id)
assert_equal_track_properties(result, single_detection_track)

0 comments on commit 4e1e241

Please sign in to comment.