Skip to content

Commit

Permalink
Merge pull request #873 from hyp1231/data
Browse files Browse the repository at this point in the history
REFACTOR: Sequential Dataset & DataLoader
  • Loading branch information
2017pxy authored Jul 10, 2021
2 parents deea3a4 + 99c8d59 commit 36bffcf
Show file tree
Hide file tree
Showing 9 changed files with 258 additions and 402 deletions.
1 change: 0 additions & 1 deletion recbole/data/dataloader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from recbole.data.dataloader.general_dataloader import *
from recbole.data.dataloader.context_dataloader import *
from recbole.data.dataloader.sequential_dataloader import *
from recbole.data.dataloader.dien_dataloader import *
from recbole.data.dataloader.knowledge_dataloader import *
from recbole.data.dataloader.decisiontree_dataloader import *
from recbole.data.dataloader.user_dataloader import *
146 changes: 0 additions & 146 deletions recbole/data/dataloader/dien_dataloader.py

This file was deleted.

122 changes: 12 additions & 110 deletions recbole/data/dataloader/sequential_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# @Email : houyupeng@ruc.edu.cn

# UPDATE
# @Time : 2020/10/6, 2020/9/17
# @Time : 2021/7/8, 2020/9/17
# @Author : Yupeng Hou, Yushuo Chen
# @email : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn

Expand All @@ -15,15 +15,15 @@
import numpy as np
import torch

from recbole.data.dataloader.abstract_dataloader import AbstractDataLoader
from recbole.data.dataloader.general_dataloader import GeneralDataLoader
from recbole.data.dataloader.neg_sample_mixin import NegSampleByMixin, NegSampleMixin
from recbole.data.interaction import Interaction, cat_interactions
from recbole.utils import DataLoaderType, FeatureSource, FeatureType, InputType
from recbole.utils import DataLoaderType, InputType


class SequentialDataLoader(AbstractDataLoader):
""":class:`SequentialDataLoader` is used for sequential model. It will do data augmentation for the origin data.
And its returned data contains the following:
class SequentialDataLoader(GeneralDataLoader):
""":class:`SequentialDataLoader` is used for sequential model.
It contains the following:
- user id
- history items list
Expand All @@ -41,109 +41,7 @@ class SequentialDataLoader(AbstractDataLoader):
:obj:`~recbole.utils.enum_type.InputType.POINTWISE`.
shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaults to ``False``.
"""
dl_type = DataLoaderType.ORIGIN

def __init__(self, config, dataset, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False):
self.uid_field = dataset.uid_field
self.iid_field = dataset.iid_field
self.time_field = dataset.time_field
self.max_item_list_len = config['MAX_ITEM_LIST_LENGTH']

list_suffix = config['LIST_SUFFIX']
for field in dataset.inter_feat:
if field != self.uid_field:
list_field = field + list_suffix
setattr(self, f'{field}_list_field', list_field)
ftype = dataset.field2type[field]

if ftype in [FeatureType.TOKEN, FeatureType.TOKEN_SEQ]:
list_ftype = FeatureType.TOKEN_SEQ
else:
list_ftype = FeatureType.FLOAT_SEQ

if ftype in [FeatureType.TOKEN_SEQ, FeatureType.FLOAT_SEQ]:
list_len = (self.max_item_list_len, dataset.field2seqlen[field])
else:
list_len = self.max_item_list_len

dataset.set_field_property(list_field, list_ftype, FeatureSource.INTERACTION, list_len)

self.item_list_length_field = config['ITEM_LIST_LENGTH_FIELD']
dataset.set_field_property(self.item_list_length_field, FeatureType.TOKEN, FeatureSource.INTERACTION, 1)

self.uid_list = dataset.uid_list
self.item_list_index = dataset.item_list_index
self.target_index = dataset.target_index
self.item_list_length = dataset.item_list_length
self.pre_processed_data = None

super().__init__(config, dataset, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle)

def data_preprocess(self):
"""Do data augmentation before training/evaluation.
"""
self.pre_processed_data = self.augmentation(self.item_list_index, self.target_index, self.item_list_length)

@property
def pr_end(self):
return len(self.uid_list)

def _shuffle(self):
if self.real_time:
new_index = torch.randperm(self.pr_end)
self.uid_list = self.uid_list[new_index]
self.item_list_index = self.item_list_index[new_index]
self.target_index = self.target_index[new_index]
self.item_list_length = self.item_list_length[new_index]
else:
self.pre_processed_data.shuffle()

def _next_batch_data(self):
cur_data = self._get_processed_data(slice(self.pr, self.pr + self.step))
self.pr += self.step
return cur_data

def _get_processed_data(self, index):
if self.real_time:
cur_data = self.augmentation(
self.item_list_index[index], self.target_index[index], self.item_list_length[index]
)
else:
cur_data = self.pre_processed_data[index]
return cur_data

def augmentation(self, item_list_index, target_index, item_list_length):
"""Data augmentation.
Args:
item_list_index (numpy.ndarray): the index of history items list in interaction.
target_index (numpy.ndarray): the index of items to be predicted in interaction.
item_list_length (numpy.ndarray): history list length.
Returns:
dict: the augmented data.
"""
new_length = len(item_list_index)
new_data = self.dataset.inter_feat[target_index]
new_dict = {
self.item_list_length_field: torch.tensor(item_list_length),
}

for field in self.dataset.inter_feat:
if field != self.uid_field:
list_field = getattr(self, f'{field}_list_field')
list_len = self.dataset.field2seqlen[list_field]
shape = (new_length, list_len) if isinstance(list_len, int) else (new_length,) + list_len
list_ftype = self.dataset.field2type[list_field]
dtype = torch.int64 if list_ftype in [FeatureType.TOKEN, FeatureType.TOKEN_SEQ] else torch.float64
new_dict[list_field] = torch.zeros(shape, dtype=dtype)

value = self.dataset.inter_feat[field]
for i, (index, length) in enumerate(zip(item_list_index, item_list_length)):
new_dict[list_field][i][:length] = value[index]

new_data.update(Interaction(new_dict))
return new_data
pass


class SequentialNegSampleDataLoader(NegSampleByMixin, SequentialDataLoader):
Expand All @@ -167,6 +65,9 @@ class SequentialNegSampleDataLoader(NegSampleByMixin, SequentialDataLoader):
def __init__(
self, config, dataset, sampler, neg_sample_args, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False
):
self.uid_field = dataset.uid_field
self.iid_field = dataset.iid_field
self.label_field = dataset.label_field
super().__init__(
config, dataset, sampler, neg_sample_args, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle
)
Expand All @@ -178,7 +79,7 @@ def _batch_size_adaptation(self):
self.upgrade_batch_size(new_batch_size)

def _next_batch_data(self):
cur_data = self._get_processed_data(slice(self.pr, self.pr + self.step))
cur_data = self.dataset[self.pr:self.pr + self.step]
cur_data = self._neg_sampling(cur_data)
self.pr += self.step

Expand Down Expand Up @@ -253,6 +154,7 @@ class SequentialFullDataLoader(NegSampleMixin, SequentialDataLoader):
def __init__(
self, config, dataset, sampler, neg_sample_args, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False
):
self.iid_field = dataset.iid_field
super().__init__(
config, dataset, sampler, neg_sample_args, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle
)
Expand Down
Loading

0 comments on commit 36bffcf

Please sign in to comment.