Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: remove .apply in predict_transport_mode #596

Merged
merged 1 commit into from
Jan 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 0 additions & 13 deletions tests/analysis/test_label.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import os

import numpy as np
import pandas as pd
import pytest

import trackintel as ti
from trackintel.analysis.labelling import _check_categories


class TestCreate_activity_flag:
Expand Down Expand Up @@ -100,14 +98,3 @@ def test_simple_coarse_identification_projected(self):
assert tpls_transport_mode_3.iloc[0]["mode"] == "slow_mobility"
assert tpls_transport_mode_3.iloc[1]["mode"] == "motorized_mobility"
assert tpls_transport_mode_3.iloc[2]["mode"] == "fast_mobility"

def test_check_categories(self):
"""Asserts the correct identification of valid category dictionaries."""
tpls_file = os.path.join("tests", "data", "triplegs_transport_mode_identification.csv")
tpls = ti.read_triplegs_csv(tpls_file, sep=";", index_col="id")
correct_dict = {2: "cat1", 7: "cat2", np.inf: "cat3"}

assert _check_categories(correct_dict)
with pytest.raises(ValueError):
incorrect_dict = {10: "cat1", 5: "cat2", np.inf: "cat3"}
tpls.as_triplegs.predict_transport_mode(method="simple-coarse", categories=incorrect_dict)
80 changes: 15 additions & 65 deletions trackintel/analysis/labelling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime

import numpy as np
import pandas as pd

from trackintel.geogr import get_speed_triplegs

Expand Down Expand Up @@ -81,90 +82,39 @@ def predict_transport_mode(triplegs, method="simple-coarse", **kwargs):
categories = kwargs.pop(
"categories", {15 / 3.6: "slow_mobility", 100 / 3.6: "motorized_mobility", np.inf: "fast_mobility"}
)

return _predict_transport_mode_simple_coarse(triplegs, categories)
triplegs = triplegs.copy()
triplegs["mode"] = _predict_transport_mode_simple_coarse(triplegs, categories)
return triplegs
else:
raise AttributeError(f"Method {method} not known for predicting tripleg transport modes.")


def _predict_transport_mode_simple_coarse(triplegs_in, categories):
def _predict_transport_mode_simple_coarse(triplegs, categories):
"""
Predict a transport mode out of three coarse classes.
Predict a transport mode based on provided categories.

Implements a simple speed based heuristic (over the whole tripleg).
As such, it is very fast, but also very simple and coarse.

Parameters
----------
triplegs_in : Triplegs
triplegs : Triplegs
The triplegs for the transport mode prediction.

categories : dict, optional
The categories for the speed classification {upper_boundary:'category_name'}.
The categories for the speed classification {upper_boundary: 'category_name'}.
The unit for the upper boundary is m/s.
The default is {15/3.6: 'slow_mobility', 100/3.6: 'motorized_mobility', np.inf: 'fast_mobility'}.

Raises
------
ValueError
In case the boundaries of the categories are not in ascending order.

Returns
-------
triplegs : trackintel triplegs GeoDataFrame
the triplegs with added column mode, containing the predicted transport modes.
cuts : pd.Series
Column containing the predicted transport modes.

For additional documentation, see
:func:`trackintel.analysis.transport_mode_identification.predict_transport_mode`.

"""
if not (_check_categories(categories)):
raise ValueError("the categories must be in increasing order")

triplegs = triplegs_in.copy()

def category_by_speed(speed):
"""
Identify the mode based on the (overall) tripleg speed.

Parameters
----------
speed : float
the speed of one tripleg

Returns
-------
str
the identified mode.
"""
for bound in categories:
if speed < bound:
return categories[bound]

triplegs_speed = get_speed_triplegs(triplegs)

triplegs["mode"] = triplegs_speed["speed"].apply(category_by_speed)
return triplegs


def _check_categories(cat):
"""
Check if the keys of a dictionary are in ascending order.

Parameters
----------
cat : disct
the dictionary to be checked.

Returns
-------
correct : bool
True if dict keys are in ascending order False otherwise.

"""
correct = True
bounds = list(cat.keys())
for i in range(len(bounds) - 1):
if bounds[i] >= bounds[i + 1]:
correct = False
return correct
categories = dict(sorted(categories.items(), key=lambda item: item[0]))
intervals = pd.IntervalIndex.from_breaks([-np.inf] + list(categories.keys()), closed="left")
speed = get_speed_triplegs(triplegs)["speed"]
cuts = pd.cut(speed, intervals)
return cuts.cat.rename_categories(categories.values())
Loading