Skip to content

Commit

Permalink
model_selection/last_n_split (#33)
Browse files Browse the repository at this point in the history
* added LastNSplitter

* fixed review suggestions

* lint fixed

* correct example

* added splitter base class

* fixed lint

* n now may be int or iterable

* n now may be int or iterable

* added LastNSplitter

* added negative N error

* fixed review mistakes

* fixed review mistakes

* attempt to avoid copying of df

* pd.DataFrame changed to pd.Series

* simplified calculations

* simplified calculations

* corrected processing of complicated index part1

* added test cases with unusual index

* temporary solution for index problem

* returned shuffle of test interactions

* updated docstring example

* updated docstring example

* rewrited test_complicated_index
  • Loading branch information
yukeeul authored May 22, 2023
1 parent eee3ba5 commit 9b3992e
Show file tree
Hide file tree
Showing 7 changed files with 369 additions and 31 deletions.
4 changes: 3 additions & 1 deletion rectools/model_selection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,23 @@
Instruments to validate and compare models.
Splitters
---------
`model_selection.Splitter` - base class for all splitters
`model_selection.KFoldSplitter` - split interactions randomly
`model_selection.LastNSplitter` - split interactions by recent activity
`model_selection.TimeRangeSplitter` - split interactions by time
"""

from .kfold_split import KFoldSplitter
from .last_n_split import LastNSplitter
from .splitter import Splitter
from .time_split import TimeRangeSplitter

# Names exported as the public API of this package.
__all__ = (
"Splitter",
"KFoldSplitter",
"LastNSplitter",
"TimeRangeSplitter",
)
27 changes: 12 additions & 15 deletions rectools/model_selection/kfold_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ class KFoldSplitter(Splitter):
Relative size of test part, must be between 0. and 1.
n_splits : int, default 1
Number of folds.
random_state: int, default None,
random_state : int, default None,
Controls randomness of each fold. Pass an int to get reproducible result across multiple `split` calls.
filter_cold_users: bool, default ``True``
filter_cold_users : bool, default ``True``
If `True`, users that are not in train will be excluded from test.
filter_cold_items: bool, default ``True``
filter_cold_items : bool, default ``True``
If `True`, items that are not in train will be excluded from test.
filter_already_seen: bool, default ``True``
filter_already_seen : bool, default ``True``
If `True`, pairs (user, item) that are in train will be excluded from test.
Examples
Expand All @@ -69,15 +69,15 @@ class KFoldSplitter(Splitter):
... filter_cold_items=False, filter_already_seen=False)
>>> for train_ids, test_ids, _ in kfs.split(interactions):
... print(train_ids, test_ids)
[0 1 2 5 6 7] [3 4]
[0 1 3 4 5 6] [2 7]
[2 7 6 1 5 0] [3 4]
[3 4 6 1 5 0] [2 7]
>>>
>>> kfs = KFoldSplitter(test_size=0.25, random_state=42, n_splits=2, filter_cold_users=True,
... filter_cold_items=True, filter_already_seen=True)
>>> for train_ids, test_ids, _ in kfs.split(interactions):
... print(train_ids, test_ids)
[0 1 2 5 6 7] [3 4]
[0 1 3 4 5 6] [2]
[2 7 6 1 5 0] [3 4]
[3 4 6 1 5 0] [2]
"""

def __init__(
Expand Down Expand Up @@ -124,12 +124,9 @@ def _split_without_filter(
shuffled_idx = rng.permutation(idx)
for i in range(self.n_splits):
fold_info = {"fold_number": i}
test_mask = np.zeros_like(idx, dtype=bool)
chosen_idx = shuffled_idx[i * test_part_size : (i + 1) * test_part_size]
test_mask[chosen_idx] = True
train_mask = ~test_mask

train_idx = idx[train_mask].values
test_idx = idx[test_mask].values
left = i * test_part_size
right = (i + 1) * test_part_size
test_idx = shuffled_idx[left:right]
train_idx = np.concatenate((shuffled_idx[:left], shuffled_idx[right:]))

yield train_idx, test_idx, fold_info
127 changes: 127 additions & 0 deletions rectools/model_selection/last_n_split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright 2023 MTS (Mobile Telesystems)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""LastNSplitter."""

import typing as tp

import numpy as np
import pandas as pd

from rectools.dataset import Interactions
from rectools.model_selection.splitter import Splitter


class LastNSplitter(Splitter):
    """
    Splitter for cross-validation by recent activity.

    Generate train and test putting last n interactions of
    each user in test and all the others in train.
    It is also possible to exclude cold users and items
    and already seen items.

    Parameters
    ----------
    n : int or iterable of ints
        Number of interactions for each user that will be included in test.
        If multiple arguments are passed, separate fold will be created for each of them.
        Every value must be positive, otherwise ``ValueError`` is raised
        when the corresponding fold is generated.
    filter_cold_users : bool, default ``True``
        If ``True``, users that are not in train will be excluded from test.
    filter_cold_items : bool, default ``True``
        If ``True``, items that are not in train will be excluded from test.
    filter_already_seen : bool, default ``True``
        If ``True``, pairs (user, item) that are in train will be excluded from test.

    Examples
    --------
    >>> from rectools import Columns
    >>> df = pd.DataFrame(
    ...     [
    ...         [1, 1, 1, "2021-09-01"], # 0
    ...         [1, 2, 1, "2021-09-02"], # 1
    ...         [1, 1, 1, "2021-08-20"], # 2
    ...         [1, 2, 1, "2021-09-04"], # 3
    ...         [2, 1, 1, "2021-08-20"], # 4
    ...         [2, 2, 1, "2021-08-20"], # 5
    ...         [2, 3, 1, "2021-09-05"], # 6
    ...         [2, 2, 1, "2021-09-06"], # 7
    ...         [3, 1, 1, "2021-09-05"], # 8
    ...     ],
    ...     columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime],
    ... ).astype({Columns.Datetime: "datetime64[ns]"})
    >>> interactions = Interactions(df)
    >>>
    >>> lns = LastNSplitter(2, False, False, False)
    >>> for train_ids, test_ids, _ in lns.split(interactions):
    ...     print(train_ids, test_ids)
    [0 2 4 5] [1 3 6 7 8]
    >>>
    >>> lns = LastNSplitter(2, True, True, True)
    >>> for train_ids, test_ids, _ in lns.split(interactions):
    ...     print(train_ids, test_ids)
    [0 2 4 5] [1 3]
    >>>
    >>> lns = LastNSplitter([1, 2], False, False, False)
    >>> for train_ids, test_ids, _ in lns.split(interactions):
    ...     print(train_ids, test_ids)
    [0 1 2 4 5 6] [3 7 8]
    [0 2 4 5] [1 3 6 7 8]
    """

    def __init__(
        self,
        n: tp.Union[int, tp.Iterable[int]],
        filter_cold_users: bool = True,
        filter_cold_items: bool = True,
        filter_already_seen: bool = True,
    ) -> None:
        super().__init__(filter_cold_users, filter_cold_items, filter_already_seen)
        # A single integer means a single fold. NumPy integer scalars are
        # neither `int` instances nor iterable, so they are accepted explicitly.
        if isinstance(n, (int, np.integer)):
            self.n = [n]
        else:
            self.n = list(n)

    def _split_without_filter(
        self,
        interactions: Interactions,
        collect_fold_stats: bool = False,
    ) -> tp.Iterator[tp.Tuple[np.ndarray, np.ndarray, tp.Dict[str, tp.Any]]]:
        """
        Yield one (train, test, info) triple for every value stored in ``self.n``.

        Parameters
        ----------
        interactions : Interactions
            User-item interactions; only ``interactions.df`` is read, and from it
            only the ``"user_id"`` and ``"datetime"`` columns.
        collect_fold_stats : bool, default False
            If ``True``, the yielded fold info contains the value of ``n``
            used for the fold under the key ``"n"``.

        Yields
        ------
        (np.ndarray, np.ndarray, dict)
            Row numbers of the train part, row numbers of the test part
            (both in ascending order) and the fold info.

        Raises
        ------
        ValueError
            If a non-positive ``n`` is reached while generating folds.
        """
        df = interactions.df
        idx = pd.RangeIndex(0, len(df))

        # Work with row positions rather than with index labels, so the result
        # does not depend on the index of the frame (duplicated labels,
        # several index levels, arbitrary label order).
        # Only one column is re-indexed; the frame itself is not copied.
        users = df["user_id"].to_numpy()
        datetimes = df["datetime"].reset_index(drop=True)
        grouped = datetimes.groupby(users)

        for n in self.n:
            if n <= 0:
                raise ValueError(f"N must be positive, got {n}")

            last_n_interactions = grouped.nlargest(n)
            # The innermost index level of the grouped result holds the row positions.
            positions = last_n_interactions.index.get_level_values(-1)
            test_idx = np.sort(np.asarray(positions, dtype=np.int64))

            train_mask = np.ones(len(idx), dtype=bool)
            train_mask[test_idx] = False
            train_idx = idx[train_mask].values

            fold_info: tp.Dict[str, tp.Any] = {}
            if collect_fold_stats:
                fold_info["n"] = n

            yield train_idx, test_idx, fold_info
18 changes: 9 additions & 9 deletions rectools/model_selection/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ def split(
Parameters
----------
interactions: Interactions
interactions : Interactions
User-item interactions.
collect_fold_stats: bool, default False
collect_fold_stats : bool, default False
Add some stats to fold info,
like size of train and test part, number of users and items.
Expand All @@ -73,9 +73,9 @@ def _split_without_filter(
Parameters
----------
interactions: Interactions
interactions : Interactions
User-item interactions.
collect_fold_stats: bool, default False
collect_fold_stats : bool, default False
Add some stats to fold info,
like size of train and test part, number of users and items.
Expand All @@ -101,16 +101,16 @@ def filter(
Parameters
----------
interactions: Interactions
interactions : Interactions
User-item interactions.
collect_fold_stats: bool, default False
collect_fold_stats : bool, default False
Add some stats to fold info,
like size of train and test part, number of users and items.
train_idx: array
train_idx : array
Train part row numbers.
test_idx: array
test_idx : array
Test part row numbers.
fold_info: dict
fold_info : dict
Information about fold.
Returns
Expand Down
8 changes: 4 additions & 4 deletions rectools/model_selection/time_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,18 @@ class TimeRangeSplitter(Splitter):
Parameters
----------
date_range: array-like(date|datetime)
date_range : array-like(date|datetime)
Ordered test fold borders.
Left will be included, right will be excluded from fold.
Interactions before first border will be used for train.
Interactions after the right border will not be used.
Can be easily generated with [`pd.date_range`]
(https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.date_range.html)
filter_cold_users: bool, default ``True``
filter_cold_users : bool, default ``True``
If `True`, users that are not in train will be excluded from test.
filter_cold_items: bool, default ``True``
filter_cold_items : bool, default ``True``
If `True`, items that are not in train will be excluded from test.
filter_already_seen: bool, default ``True``
filter_already_seen : bool, default ``True``
If ``True``, pairs (user, item) that are in train will be excluded from test.
Examples
Expand Down
4 changes: 2 additions & 2 deletions rectools/model_selection/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def get_not_seen_mask(
Returns
-------
np.ndarray
Boolean mask of same length as `test_users` (`test_items`).
``True`` means interaction not present in train.
Boolean mask of same length as `test_users` (`test_items`).
``True`` means interaction not present in train.
"""
if train_users.size != train_items.size:
raise ValueError("Lengths of `train_users` and `train_items` must be the same")
Expand Down
Loading

0 comments on commit 9b3992e

Please sign in to comment.