Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implicit sequence weights and Implicit factorization weights #122

Open
wants to merge 21 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
6952630
MAINT: adding spotlight results files to gitignore
nikisix Jul 11, 2018
910fb9a
feature: sample weight implementation for loss functions and implicit…
nikisix Jul 17, 2018
f8a8409
lint: linting changes
nikisix Jul 17, 2018
788f369
partial: start modifying sequence interactions to incorporate sample_…
nikisix Jul 18, 2018
191356b
MAINT: commit based on pr 120 discussion. rename base_loss to _weight…
nikisix Jul 18, 2018
528bb48
Merge branch 'sample-weights' into seq-weights
nikisix Jul 18, 2018
0ba7164
LINT: BUG:
nikisix Jul 18, 2018
91a4d22
LINT: BUG:
nikisix Jul 18, 2018
3fa6246
FEATURE: sequence sample weights working
nikisix Jul 23, 2018
28faf31
MAINT: augmenting implicit_factorizers fit method to handle sample_we…
nikisix Jul 23, 2018
68ac9f2
Merge branch 'sample-weights' into seq-weights
nikisix Jul 23, 2018
b4a6a7d
FEAT: implicit sequence masking accomplished via a sample weight of zero
nikisix Jul 24, 2018
9bd923e
MAINT: combining Interactions.to_sequence and Interactions.to_weighte…
nikisix Jul 24, 2018
058edde
LINT: linting interactions and sequence/implicit
nikisix Jul 24, 2018
02e3dda
BUG: DOC: fix weight flow logic in sequence implicit, and add a doc s…
nikisix Jul 24, 2018
6bcdd16
MAINT: casting to cuda tensor if gpu is enabled
jwinemiller-aa Jul 31, 2018
731835b
MAINT: one more gpu cast
nikisix Jul 31, 2018
64af649
BUG: shuffling weight sequences alongside interactions. #122
nikisix Aug 6, 2018
0b3b5e2
MAINT: removing masks from losses
nikisix Aug 29, 2018
3905bc4
TEST: adding sequential weight tests for zeros, normal(ones), and hig…
nikisix Aug 31, 2018
43b32be
LINT: datasets/synthetic and tests/sequence/tests_sequence_weights
nikisix Sep 6, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# spotlight results files
*_results.txt

*~
*#*

Expand All @@ -18,3 +21,4 @@

# IDE
tags
cscope.out
26 changes: 24 additions & 2 deletions spotlight/datasets/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def generate_sequential(num_users=100,
num_interactions=10000,
concentration_parameter=0.1,
order=3,
random_state=None):
random_state=None,
weight_type=None
):
"""
Generate a dataset of user-item interactions where sequential
information matters.
Expand Down Expand Up @@ -100,6 +102,8 @@ def generate_sequential(num_users=100,
order of the Markov chain
random_state: numpy.random.RandomState, optional
random state used to generate the data
weight_type: string, optional
    Type of synthetic sample weights to generate. Must be one of
    'ones', 'zeros', or 'high'; if None, no weights are generated.

Returns
-------
Expand All @@ -108,6 +112,22 @@ def generate_sequential(num_users=100,
instance of the interactions class
"""

weights = None
weight_types = ['ones', 'zeros', 'high']
if weight_type is not None:
if weight_type not in weight_types:
raise ValueError(
"weight_type {} not in {}"
.format(weight_type, weight_types)
)
if weight_type == 'ones':
weights = np.ones(num_interactions)
elif weight_type == 'zeros':
weights = np.zeros(num_interactions)
elif weight_type == 'high':
large_weight = 1E9
weights = large_weight * np.ones(num_interactions)

if random_state is None:
random_state = np.random.RandomState()

Expand All @@ -132,4 +152,6 @@ def generate_sequential(num_users=100,
ratings=ratings,
timestamps=timestamps,
num_users=num_users,
num_items=num_items)
num_items=num_items,
weights=weights
)
47 changes: 36 additions & 11 deletions spotlight/factorization/implicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def _check_input(self, user_ids, item_ids, allow_items_none=False):

def fit(self, interactions, verbose=False):
"""
Fit the model.
Fit the model using sample weights.

When called repeatedly, model fitting will resume from
the point at which training stopped in the previous fit
Expand All @@ -198,9 +198,11 @@ def fit(self, interactions, verbose=False):
verbose: bool
Output additional information about current epoch and loss.
"""

user_ids = interactions.user_ids.astype(np.int64)
item_ids = interactions.item_ids.astype(np.int64)
sample_weights = None
if interactions.weights is not None:
sample_weights = interactions.weights.astype(np.float32)

if not self._initialized:
self._initialize(interactions)
Expand All @@ -209,22 +211,41 @@ def fit(self, interactions, verbose=False):

for epoch_num in range(self._n_iter):

users, items = shuffle(user_ids,
item_ids,
random_state=self._random_state)
users, items, sample_weights = shuffle(
user_ids,
item_ids,
sample_weights,
random_state=self._random_state
)

user_ids_tensor = gpu(torch.from_numpy(users),
self._use_cuda)
item_ids_tensor = gpu(torch.from_numpy(items),
self._use_cuda)
sample_weights_tensor = None
if sample_weights is not None:
sample_weights_tensor = gpu(
torch.from_numpy(sample_weights),
self._use_cuda
)

epoch_loss = 0.0

for (minibatch_num,
(batch_user,
batch_item)) in enumerate(minibatch(user_ids_tensor,
item_ids_tensor,
batch_size=self._batch_size)):
for (
minibatch_num,
(
batch_user,
batch_item,
batch_sample_weights
)
) in enumerate(
minibatch(
user_ids_tensor,
item_ids_tensor,
sample_weights_tensor,
batch_size=self._batch_size
)
):

positive_prediction = self._net(batch_user, batch_item)

Expand All @@ -236,7 +257,11 @@ def fit(self, interactions, verbose=False):

self._optimizer.zero_grad()

loss = self._loss_func(positive_prediction, negative_prediction)
loss = self._loss_func(
positive_prediction,
negative_prediction,
sample_weights=batch_sample_weights
)
epoch_loss += loss.item()

loss.backward()
Expand Down
39 changes: 31 additions & 8 deletions spotlight/interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def _sliding_window(tensor, window_size, step_size=1):
yield tensor[max(i - window_size, 0):i]


def _generate_sequences(user_ids, item_ids,
def _generate_sequences(user_ids, sequence_elements,
indices,
max_sequence_length,
step_size):
Expand All @@ -28,7 +28,7 @@ def _generate_sequences(user_ids, item_ids,
else:
stop_idx = indices[i + 1]

for seq in _sliding_window(item_ids[start_idx:stop_idx],
for seq in _sliding_window(sequence_elements[start_idx:stop_idx],
max_sequence_length,
step_size):

Expand Down Expand Up @@ -63,7 +63,7 @@ class Interactions(object):
timestamps: array of np.int32, optional
array of timestamps
weights: array of np.float32, optional
array of weights
array of sample importance weights
num_users: int, optional
Number of distinct users in the dataset.
Must be larger than the maximum user id
Expand All @@ -85,7 +85,7 @@ class Interactions(object):
timestamps: array of np.int32, optional
array of timestamps
weights: array of np.float32, optional
array of weights
array of sample importance weights
num_users: int, optional
Number of distinct users in the dataset.
num_items: int, optional
Expand Down Expand Up @@ -218,7 +218,7 @@ def to_sequence(self, max_sequence_length=10, min_sequence_length=None, step_siz
sequence interactions: :class:`~SequenceInteractions`
The resulting sequence interactions.
"""

weighted = self.weights is not None
if self.timestamps is None:
raise ValueError('Cannot convert to sequences, '
'timestamps not available.')
Expand All @@ -236,6 +236,8 @@ def to_sequence(self, max_sequence_length=10, min_sequence_length=None, step_siz

user_ids = self.user_ids[sort_indices]
item_ids = self.item_ids[sort_indices]
if weighted:
weights = self.weights[sort_indices]

user_ids, indices, counts = np.unique(user_ids,
return_index=True,
Expand All @@ -245,6 +247,10 @@ def to_sequence(self, max_sequence_length=10, min_sequence_length=None, step_siz

sequences = np.zeros((num_subsequences, max_sequence_length),
dtype=np.int32)
weight_sequences = None
if weighted:
weight_sequences = np.zeros((num_subsequences, max_sequence_length),
dtype=np.int32)
sequence_users = np.empty(num_subsequences,
dtype=np.int32)
for i, (uid,
Expand All @@ -256,13 +262,25 @@ def to_sequence(self, max_sequence_length=10, min_sequence_length=None, step_siz
sequences[i][-len(seq):] = seq
sequence_users[i] = uid

if weighted:
for i, (uid,
seq) in enumerate(_generate_sequences(user_ids,
weights,
indices,
max_sequence_length,
step_size)):
weight_sequences[i][-len(seq):] = seq

if min_sequence_length is not None:
long_enough = sequences[:, -min_sequence_length] != 0
sequences = sequences[long_enough]
sequence_users = sequence_users[long_enough]
if weighted:
weight_sequences = weight_sequences[long_enough]

return (SequenceInteractions(sequences,
user_ids=sequence_users,
weight_sequences=weight_sequences,
num_items=self.num_items))


Expand All @@ -276,6 +294,11 @@ class SequenceInteractions(object):
sequences: array of np.int32 of shape (num_sequences x max_sequence_length)
The interactions sequence matrix, as produced by
:func:`~Interactions.to_sequence`
user_ids: array of np.int32, optional
    Array of user ids, one per sequence: the user whose
    interactions each row of item_id sequences represents.
weight_sequences: array of np.float32 of shape
    (num_sequences x max_sequence_length), optional.
    Per-position sample weights aligned with the item sequences.
num_items: int, optional
The number of distinct items in the data

Expand All @@ -287,11 +310,11 @@ class SequenceInteractions(object):
:func:`~Interactions.to_sequence`
"""

def __init__(self,
sequences,
user_ids=None, num_items=None):
def __init__(self, sequences,
user_ids=None, weight_sequences=None, num_items=None):

self.sequences = sequences
self.weight_sequences = weight_sequences
self.user_ids = user_ids
self.max_sequence_length = sequences.shape[1]

Expand Down
Loading