REFACTOR: Sequential Dataset & DataLoader #873

Merged: 13 commits, Jul 10, 2021
1 change: 0 additions & 1 deletion recbole/data/dataloader/__init__.py
@@ -3,7 +3,6 @@
from recbole.data.dataloader.general_dataloader import *
from recbole.data.dataloader.context_dataloader import *
from recbole.data.dataloader.sequential_dataloader import *
from recbole.data.dataloader.dien_dataloader import *
from recbole.data.dataloader.knowledge_dataloader import *
from recbole.data.dataloader.decisiontree_dataloader import *
from recbole.data.dataloader.user_dataloader import *
146 changes: 0 additions & 146 deletions recbole/data/dataloader/dien_dataloader.py

This file was deleted.

122 changes: 12 additions & 110 deletions recbole/data/dataloader/sequential_dataloader.py
@@ -3,7 +3,7 @@
# @Email : houyupeng@ruc.edu.cn

# UPDATE
# @Time : 2020/10/6, 2020/9/17
# @Time : 2021/7/8, 2020/9/17
# @Author : Yupeng Hou, Yushuo Chen
# @email : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn

@@ -15,15 +15,15 @@
import numpy as np
import torch

from recbole.data.dataloader.abstract_dataloader import AbstractDataLoader
from recbole.data.dataloader.general_dataloader import GeneralDataLoader
from recbole.data.dataloader.neg_sample_mixin import NegSampleByMixin, NegSampleMixin
from recbole.data.interaction import Interaction, cat_interactions
from recbole.utils import DataLoaderType, FeatureSource, FeatureType, InputType
from recbole.utils import DataLoaderType, InputType


class SequentialDataLoader(AbstractDataLoader):
""":class:`SequentialDataLoader` is used for sequential model. It will do data augmentation for the origin data.
And its returned data contains the following:
class SequentialDataLoader(GeneralDataLoader):
""":class:`SequentialDataLoader` is used for sequential model.
It contains the following:
- user id
- history items list
@@ -41,109 @@ class SequentialDataLoader(AbstractDataLoader):
:obj:`~recbole.utils.enum_type.InputType.POINTWISE`.
shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaults to ``False``.
"""
dl_type = DataLoaderType.ORIGIN

def __init__(self, config, dataset, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False):
self.uid_field = dataset.uid_field
self.iid_field = dataset.iid_field
self.time_field = dataset.time_field
self.max_item_list_len = config['MAX_ITEM_LIST_LENGTH']

list_suffix = config['LIST_SUFFIX']
for field in dataset.inter_feat:
if field != self.uid_field:
list_field = field + list_suffix
setattr(self, f'{field}_list_field', list_field)
ftype = dataset.field2type[field]

if ftype in [FeatureType.TOKEN, FeatureType.TOKEN_SEQ]:
list_ftype = FeatureType.TOKEN_SEQ
else:
list_ftype = FeatureType.FLOAT_SEQ

if ftype in [FeatureType.TOKEN_SEQ, FeatureType.FLOAT_SEQ]:
list_len = (self.max_item_list_len, dataset.field2seqlen[field])
else:
list_len = self.max_item_list_len

dataset.set_field_property(list_field, list_ftype, FeatureSource.INTERACTION, list_len)

self.item_list_length_field = config['ITEM_LIST_LENGTH_FIELD']
dataset.set_field_property(self.item_list_length_field, FeatureType.TOKEN, FeatureSource.INTERACTION, 1)

self.uid_list = dataset.uid_list
self.item_list_index = dataset.item_list_index
self.target_index = dataset.target_index
self.item_list_length = dataset.item_list_length
self.pre_processed_data = None

super().__init__(config, dataset, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle)

def data_preprocess(self):
"""Do data augmentation before training/evaluation.
"""
self.pre_processed_data = self.augmentation(self.item_list_index, self.target_index, self.item_list_length)

@property
def pr_end(self):
return len(self.uid_list)

def _shuffle(self):
if self.real_time:
new_index = torch.randperm(self.pr_end)
self.uid_list = self.uid_list[new_index]
self.item_list_index = self.item_list_index[new_index]
self.target_index = self.target_index[new_index]
self.item_list_length = self.item_list_length[new_index]
else:
self.pre_processed_data.shuffle()

def _next_batch_data(self):
cur_data = self._get_processed_data(slice(self.pr, self.pr + self.step))
self.pr += self.step
return cur_data

def _get_processed_data(self, index):
if self.real_time:
cur_data = self.augmentation(
self.item_list_index[index], self.target_index[index], self.item_list_length[index]
)
else:
cur_data = self.pre_processed_data[index]
return cur_data

def augmentation(self, item_list_index, target_index, item_list_length):
"""Data augmentation.
Args:
item_list_index (numpy.ndarray): the index of history items list in interaction.
target_index (numpy.ndarray): the index of items to be predicted in interaction.
item_list_length (numpy.ndarray): history list length.
Returns:
dict: the augmented data.
"""
new_length = len(item_list_index)
new_data = self.dataset.inter_feat[target_index]
new_dict = {
self.item_list_length_field: torch.tensor(item_list_length),
}

for field in self.dataset.inter_feat:
if field != self.uid_field:
list_field = getattr(self, f'{field}_list_field')
list_len = self.dataset.field2seqlen[list_field]
shape = (new_length, list_len) if isinstance(list_len, int) else (new_length,) + list_len
list_ftype = self.dataset.field2type[list_field]
dtype = torch.int64 if list_ftype in [FeatureType.TOKEN, FeatureType.TOKEN_SEQ] else torch.float64
new_dict[list_field] = torch.zeros(shape, dtype=dtype)

value = self.dataset.inter_feat[field]
for i, (index, length) in enumerate(zip(item_list_index, item_list_length)):
new_dict[list_field][i][:length] = value[index]

new_data.update(Interaction(new_dict))
return new_data
pass
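
For context on what the deleted block above did: `augmentation` took the index arrays prepared by the sequential dataset (`item_list_index`, `target_index`, `item_list_length`) and packed each history into a fixed-length, zero-padded tensor. Below is a minimal, self-contained sketch of that sliding-window idea; it is an illustration only, not the RecBole implementation (after this refactor the equivalent logic presumably lives in the sequential Dataset), and the helper name `augment` plus the toy sequence are hypothetical.

```python
# Sliding-window augmentation sketch: for one user's chronologically ordered
# item ids, build (padded history, history length, target) training triples.
# Zero is used as padding and histories keep only the most recent
# max_item_list_len items, mirroring the removed dataloader code above.
import torch


def augment(item_seq, max_item_list_len=50):
    histories, lengths, targets = [], [], []
    for i in range(1, len(item_seq)):
        history = item_seq[max(0, i - max_item_list_len):i]
        padded = torch.zeros(max_item_list_len, dtype=torch.int64)
        padded[:len(history)] = torch.tensor(history)
        histories.append(padded)
        lengths.append(len(history))
        targets.append(item_seq[i])
    return torch.stack(histories), torch.tensor(lengths), torch.tensor(targets)


# Toy usage: a user who interacted with items 3, 7, 7, 2, 9 in that order.
item_id_list, item_length, target = augment([3, 7, 7, 2, 9], max_item_list_len=3)
# item_id_list -> [[3, 0, 0], [3, 7, 0], [3, 7, 7], [7, 7, 2]]
# item_length  -> [1, 2, 3, 3]
# target       -> [7, 7, 2, 9]
```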


class SequentialNegSampleDataLoader(NegSampleByMixin, SequentialDataLoader):
@@ -167,6 +65,9 @@
def __init__(
self, config, dataset, sampler, neg_sample_args, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False
):
self.uid_field = dataset.uid_field
self.iid_field = dataset.iid_field
self.label_field = dataset.label_field
super().__init__(
config, dataset, sampler, neg_sample_args, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle
)
@@ -178,7 +79,7 @@ def _batch_size_adaptation(self):
self.upgrade_batch_size(new_batch_size)

def _next_batch_data(self):
cur_data = self._get_processed_data(slice(self.pr, self.pr + self.step))
cur_data = self.dataset[self.pr:self.pr + self.step]
cur_data = self._neg_sampling(cur_data)
self.pr += self.step
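
The `_neg_sampling` call above comes from `NegSampleByMixin` and is not shown in this diff. As a rough sketch of what pointwise negative sampling does to a batch, assuming the usual behaviour of pairing each positive (user, target) row with a few sampled items the user has not interacted with: everything in the snippet (the `sample_negatives` helper, the toy history dict) is hypothetical, not RecBole code.

```python
import random


def sample_negatives(batch, user_history, n_items, neg_sample_by=1):
    """batch: list of (user_id, positive_item). Returns (user, item, label) rows."""
    rows = []
    for user, pos_item in batch:
        rows.append((user, pos_item, 1))        # positive sample, label 1
        seen = user_history.get(user, set())
        for _ in range(neg_sample_by):
            neg = random.randrange(1, n_items)  # redraw until an unseen item is found
            while neg in seen or neg == pos_item:
                neg = random.randrange(1, n_items)
            rows.append((user, neg, 0))         # negative sample, label 0
    return rows


# Toy usage: two positives, one sampled negative each.
print(sample_negatives([(1, 42), (2, 7)], {1: {3, 42}, 2: {7}}, n_items=100))
```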

@@ -253,6 +154,7 @@ class SequentialFullDataLoader(NegSampleMixin, SequentialDataLoader):
def __init__(
self, config, dataset, sampler, neg_sample_args, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False
):
self.iid_field = dataset.iid_field
super().__init__(
config, dataset, sampler, neg_sample_args, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle
)
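
For the full dataloader, the usual evaluation setup (not spelled out in this diff) is full ranking: score every candidate item for each test sequence, mask items already seen, and read off the rank of the ground-truth target. A generic sketch with hypothetical names, not the RecBole evaluation code:

```python
import torch


def rank_of_target(scores, target, history):
    """scores: (n_items,) predicted scores for one sequence; returns the target's rank."""
    scores = scores.clone()
    scores[list(history)] = -float('inf')  # mask items the user has already interacted with
    # rank = 1 + number of items scored higher than the ground-truth target
    return int((scores > scores[target]).sum().item()) + 1


scores = torch.randn(100)  # toy scores over 100 items
print(rank_of_target(scores, target=42, history={3, 17}))
```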