
Commit

Merge pull request #36 from gsbDBI/34-bug-fix-reshape_observable-method-in-bembpy

34 bug fix reshape observable method in bembpy
TianyuDu authored Sep 15, 2023
2 parents c56c0b7 + 6e00d38 commit c0d6327
Showing 4 changed files with 584 additions and 54 deletions.
71 changes: 46 additions & 25 deletions bemb/model/bemb.py
@@ -57,7 +57,18 @@ def parse_utility(utility_string: str) -> List[Dict[str, Union[List[str], None]]
"""
# split additive terms
coefficient_suffix = ('_item', '_user', '_constant', '_category')
observable_prefix = ('item_', 'user_', 'session_', 'price_', 'taste_')
observable_prefix = ('item_', 'user_', 'session_',
# (user, item)-specific.
'useritem_', 'itemuser_', 'taste_',
# (user, session)-specific.
'usersession', 'sessionuser',
# (session, item)-specific.
'itemsession_', 'sessionitem_', 'price_',
# (user, session, item)-specific.
'usersessionitem_', 'useritemsession_',
'sessionuseritem_', 'sessionitemuser_',
'itemusersession_', 'itemsessionuser_',
)

def is_coefficient(name: str) -> bool:
return any(name.endswith(suffix) for suffix in coefficient_suffix)
@@ -71,18 +82,13 @@ def is_observable(name: str) -> bool:
atom = {'coefficient': [], 'observable': None}
# split multiplicative terms.
for x in term.split(' * '):
# Programmers can specify itemsession_ for price observables; this naming is more intuitive.
if x.startswith('itemsession_'):
# case 1: special observable name.
atom['observable'] = 'price_' + x[len('itemsession_'):]
elif is_observable(x):
# case 2: normal observable name.
assert not (is_observable(x) and is_coefficient(x)), f"The element {x} is ambiguous: it follows the naming conventions of both an observable and a coefficient."
if is_observable(x):
atom['observable'] = x
elif is_coefficient(x):
# case 3: normal coefficient name.
atom['coefficient'].append(x)
else:
# case 4: special coefficient name.
# case 3: special coefficient name.
# the _constant coefficient suffix is not very intuitive, so we also allow omitting the suffix for
# coefficients with a constant value. For example, `lambda` is treated the same as `lambda_constant`.
warnings.warn(f'The term {x} has no recognized coefficient suffix or observable prefix; it is assumed to be a coefficient that is constant across all items, users, and sessions. If this is the desired behavior, you can specify `{x}_constant` in the utility formula to avoid this warning. The utility parser has replaced the term {x} with `{x}_constant`.')
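
As a quick illustration (not part of this diff; the import path is assumed from the file layout above), this is the decomposition parse_utility produces for a small formula, consistent with the atoms built above and the parser tests later in this commit:

from bemb.model.bemb import parse_utility

atoms = parse_utility('theta_user * alpha_item + gamma_user * user_obs + lambda')
# each additive term becomes one atom holding a list of coefficients and at most one observable:
# atoms[0] == {'coefficient': ['theta_user', 'alpha_item'], 'observable': None}
# atoms[1] == {'coefficient': ['gamma_user'], 'observable': 'user_obs'}
# atoms[2] == {'coefficient': ['lambda_constant'], 'observable': None}  # bare names get the _constant suffix (with a warning)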
@@ -303,9 +309,6 @@ def __init__(self,
'user': num_user_obs,
'item': num_item_obs,
'category' : 0,
'session': num_session_obs,
'price': num_price_obs,
'taste': num_taste_obs,
'constant': 1 # not really used, for dummy variables.
}

@@ -853,25 +856,35 @@ def reshape_observable(obs, name):
# samples of coefficients.
O = obs.shape[-1] # number of observables.
assert O == positive_integer
if name.startswith('item_'):
if batch._is_item_attribute(name):
assert obs.shape == (I, O)
obs = obs.view(1, 1, I, O).expand(R, P, -1, -1)
elif name.startswith('user_'):
elif batch._is_user_attribute(name):
assert obs.shape == (U, O)
obs = obs[user_index, :] # (P, O)
obs = obs.view(1, P, 1, O).expand(R, -1, I, -1)
elif name.startswith('session_'):
elif batch._is_session_attribute(name):
assert obs.shape == (S, O)
obs = obs[session_index, :] # (P, O)
return obs.view(1, P, 1, O).expand(R, -1, I, -1)
elif name.startswith('price_'):
obs = obs.view(1, P, 1, O).expand(R, -1, I, -1)
elif batch._is_price_attribute(name):
assert obs.shape == (S, I, O)
obs = obs[session_index, :, :] # (P, I, O)
return obs.view(1, P, I, O).expand(R, -1, -1, -1)
elif name.startswith('taste_'):
obs = obs.view(1, P, I, O).expand(R, -1, -1, -1)
elif batch._is_useritem_attribute(name):
assert obs.shape == (U, I, O)
obs = obs[user_index, :, :] # (P, I, O)
return obs.view(1, P, I, O).expand(R, -1, -1, -1)
obs = obs.view(1, P, I, O).expand(R, -1, -1, -1)
elif batch._is_usersession_attribute(name):
assert obs.shape == (U, S, O)
obs = obs[user_index, session_index, :] # (P, O)
assert obs.shape == (P, O)
obs = obs.view(1, P, 1, O).expand(R, -1, I, -1)
elif batch._is_usersessionitem_attribute(name):
assert obs.shape == (U, S, I, O)
obs = obs[user_index, session_index, :, :] # (P, I, O)
assert obs.shape == (P, I, O)
obs = obs.view(1, P, I, O).expand(R, -1, -1, -1)
else:
raise ValueError
assert obs.shape == (R, P, I, O)
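
All of the new _is_*_attribute branches above follow the same index-then-broadcast pattern; below is a minimal standalone sketch (not from this commit, with illustrative sizes and names) of the (user, session)-specific case, where R is the number of Monte Carlo samples, P the number of observations in the batch, and I the number of items:

import torch

R, P, U, S, I, O = 4, 6, 3, 5, 7, 2                  # illustrative sizes
obs = torch.randn(U, S, O)                           # a usersession_* observable
user_index = torch.randint(U, (P,))                  # one user per observation
session_index = torch.randint(S, (P,))               # one session per observation

obs = obs[user_index, session_index, :]              # advanced indexing: one row per observation, (P, O)
obs = obs.view(1, P, 1, O).expand(R, -1, I, -1)      # broadcast over samples and items
assert obs.shape == (R, P, I, O)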
@@ -1056,6 +1069,8 @@ def log_likelihood_item_index(self, batch: ChoiceDataset, return_logit: bool, sa
U = self.num_users
I = self.num_items
NC = self.num_categories

assert len(user_index) == len(session_index) == len(relevant_item_index) == total_computation
# ==========================================================================================
# Helper Functions for Reshaping.
# ==========================================================================================
@@ -1082,21 +1097,27 @@ def reshape_observable(obs, name):
# samples of coefficients.
O = obs.shape[-1] # number of observables.
assert O == positive_integer
if name.startswith('item_'):
if batch._is_item_attribute(name):
assert obs.shape == (I, O)
obs = obs[relevant_item_index, :]
elif name.startswith('user_'):
elif batch._is_user_attribute(name):
assert obs.shape == (U, O)
obs = obs[user_index, :]
elif name.startswith('session_'):
elif batch._is_session_attribute(name):
assert obs.shape == (S, O)
obs = obs[session_index, :]
elif name.startswith('price_'):
elif batch._is_price_attribute(name):
assert obs.shape == (S, I, O)
obs = obs[session_index, relevant_item_index, :]
elif name.startswith('taste_'):
elif batch._is_useritem_attribute(name):
assert obs.shape == (U, I, O)
obs = obs[user_index, relevant_item_index, :]
elif batch._is_usersession_attribute(name):
assert obs.shape == (U, S, O)
obs = obs[user_index, session_index, :] # (total_computation, O)
elif batch._is_usersessionitem_attribute(name):
assert obs.shape == (U, S, I, O)
obs = obs[user_index, session_index, relevant_item_index, :]
else:
raise ValueError
assert obs.shape == (total_computation, O)
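
The item_index variant above instead gathers a single row per computed term rather than broadcasting over all items; a minimal sketch (again with illustrative names, not from this commit) of the (user, session, item)-specific case:

import torch

N, U, S, I, O = 6, 3, 5, 7, 2                        # N plays the role of total_computation
obs = torch.randn(U, S, I, O)                        # a usersessionitem_* observable
user_index = torch.randint(U, (N,))
session_index = torch.randint(S, (N,))
relevant_item_index = torch.randint(I, (N,))

obs = obs[user_index, session_index, relevant_item_index, :]  # one (user, session, item) slice per row
assert obs.shape == (N, O)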
56 changes: 56 additions & 0 deletions tests/simulate_choice_dataset.py
@@ -40,3 +40,59 @@ def simulate_dataset(num_users: int, num_items: int, data_size: int) -> List[Cho

dataset_list = [dataset[train_idx], dataset[val_idx], dataset[test_idx]]
return dataset_list


def simulate_dataset_v2(num_users: int, num_items: int, num_sessions: int, data_size: int) -> List[ChoiceDataset]:
length_of_dataset = data_size # $N$
# create observables/features; the numbers of features are chosen arbitrarily.
# generate 128 features for each user, e.g., race, gender.
# these variables should have shape (num_users, *)
user_obs = torch.randn(num_users, 128)
# generate 64 features for each item, e.g., quality.
item_obs = torch.randn(num_items, 64)
# generate 10 features for each session, e.g., weekday indicator.
session_obs = torch.randn(num_sessions, 10)
# generate 12 features for each (session, item) pair, e.g., the price of that item in that session.
itemsession_obs = torch.randn(num_sessions, num_items, 12)
# generate 12 features for each user item pair, e.g., the user's preference on that item.
useritem_obs = torch.randn(num_users, num_items, 12)
# generate 10 features for each (user, session) pair, e.g., the user's historical spending as of that session.
usersession_obs = torch.randn(num_users, num_sessions, 10)
# generate 8 features for each (user, session, item) triple, e.g., the user's preference for that item in that session.
# since `U*S*I` is potentially huge and may cause identifiability issues, we rarely use this kind of observable in practice.
usersessionitem_obs = torch.randn(num_users, num_sessions, num_items, 8)

# generate the array of item[n].
item_index = torch.LongTensor(np.random.choice(num_items, size=length_of_dataset))
# generate the array of user[n].
user_index = torch.LongTensor(np.random.choice(num_users, size=length_of_dataset))
# generate the array of session[n].
session_index = torch.LongTensor(np.random.choice(num_sessions, size=length_of_dataset))

# assume all items are available in all sessions.
item_availability = torch.ones(num_sessions, num_items).bool()

dataset = ChoiceDataset(
# pre-specified keywords of __init__
item_index=item_index, # required.
num_items=num_items,
# optional:
user_index=user_index,
num_users=num_users,
session_index=session_index,
item_availability=item_availability,
# additional keywords of __init__
user_obs=user_obs,
item_obs=item_obs,
session_obs=session_obs,
itemsession_obs=itemsession_obs,
useritem_obs=useritem_obs,
usersession_obs=usersession_obs,
usersessionitem_obs=usersessionitem_obs)

# we can subset the dataset by conventional python indexing.
dataset_train = dataset[:int(0.8*len(dataset))]
dataset_val = dataset[int(0.8*len(dataset)):int(0.9*len(dataset))]
dataset_test = dataset[int(0.9*len(dataset)):]

return [dataset_train, dataset_val, dataset_test]
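
A hypothetical usage sketch of the new helper (it assumes this file is importable as simulate_choice_dataset, as the test file below does, and that ChoiceDataset slicing and len() behave as used above):

import simulate_choice_dataset

train, val, test = simulate_choice_dataset.simulate_dataset_v2(
    num_users=50, num_items=100, num_sessions=500, data_size=10000)

# the 80/10/10 slices above partition the simulated choices.
assert len(train) + len(val) + len(test) == 10000
# observable feature dimensions match what was generated above (the test below reads them the same way).
assert train.user_obs.shape[1] == 128
assert train.usersessionitem_obs.shape[-1] == 8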
93 changes: 64 additions & 29 deletions tests/test_bemb_functionality.py
@@ -17,6 +17,7 @@
global num_users, num_items, data_size
num_users = 50
num_items = 100
num_sessions = 500
data_size = 10000
num_seeds = 32

@@ -35,35 +36,13 @@ def test_parser_and_model_creation(self):
self.assertTrue(additive_decomposition[5]['coefficient'] == ['delta_item'] and additive_decomposition[5]['observable'] == 'item_obs')
self.assertTrue(additive_decomposition[6]['coefficient'] == ['comp_user', 'comp_item'] and additive_decomposition[6]['observable'] == 'user_obs')

def test_parser_price_and_itemsession_obs_1(self):
formula_1 = 'alpha_item * price_obs'
formula_2 = 'alpha_item * itemsession_obs'
u1 = parse_utility(formula_1)
u2 = parse_utility(formula_2)
self.assertTrue(u1 == u2)

def test_parser_price_and_itemsession_obs_2(self):
formula_1 = 'alpha_item * price_obs + user_obs * gamma_user + item_obs * delta_item'
formula_2 = 'alpha_item * itemsession_obs + user_obs * gamma_user + item_obs * delta_item'
u1 = parse_utility(formula_1)
u2 = parse_utility(formula_2)
self.assertTrue(u1 == u2)

def test_parser_constant_and_null_coef(self):
formula_1 = 'user_obs * gamma_constant + alpha_item * beta_item'
formula_2 = 'user_obs * gamma + alpha_item * beta_item'
u1 = parse_utility(formula_1)
u2 = parse_utility(formula_2)
self.assertTrue(u1 == u2)

def test_parser(self):
formula_1 = 'price_obs * gamma_constant + alpha_item * beta_item'
formula_2 = 'itemsession_obs * gamma + alpha_item * beta_item'

u1 = parse_utility(formula_1)
u2 = parse_utility(formula_2)
self.assertTrue(u1 == u2)

class TestBEMBFlex(unittest.TestCase):
"""
Testing core functionality of bemb.
@@ -134,20 +113,76 @@ def test_predict_proba_shape(self):
dataset_list = simulate_choice_dataset.simulate_dataset(num_users=num_users, num_items=num_items, data_size=data_size)
batch = dataset_list[-1]

for pred_item in [True, False]:

class TestBEMBFlexV2(unittest.TestCase):
"""
Testing core functionality of bemb.
"""
# def __init__(self):
# pass

# def test_initialization(self):
# pass

# def test_estimation(self):
# pass

# ==================================================================================================================
# Test Arguments and Options in the Initialization Method
# ==================================================================================================================
def test_init(self):
pass

def test_H_zero_mask(self):
pass

# ==================================================================================================================
# Test API Methods
# ==================================================================================================================
def test_prediction_shapes(self):
dataset_list = simulate_choice_dataset.simulate_dataset_v2(num_users=num_users, num_items=num_items, num_sessions=num_sessions, data_size=data_size)
batch = dataset_list[0]
# test different variations of the forward function.
# return_type X return_scope X deterministic X pred_items.
for return_type, return_scope, deterministic, pred_item in itertools.product(['utility', 'log_prob'], ['item_index', 'all_items'], [True, False], [True, False]):
if not pred_item:
# generate fake binary labels.
batch.label = torch.LongTensor(np.random.randint(2, size=len(batch)))

# initialize the model.
bemb = BEMBFlex(
pred_item=pred_item,
utility_formula='theta_user * alpha_item',
utility_formula="a_user + b_item + c_constant + d_user * e_item + f1_constant * user_obs + f2_constant * item_obs + f3_constant * session_obs + f4_constant * useritem_obs + f5_constant * usersession_obs + f6_constant * itemsession_obs + f7_constant * usersessionitem_obs",
num_users=num_users,
num_items=num_items,
num_sessions=num_sessions,
num_classes=None if pred_item else 2,
num_user_obs=dataset_list[0].user_obs.shape[1],
num_item_obs=dataset_list[0].item_obs.shape[1],
obs2prior_dict={'theta_user': True, 'alpha_item': True},
coef_dim_dict={'theta_user': 10, 'alpha_item': 10}
num_user_obs=batch.user_obs.shape[1],
num_item_obs=batch.item_obs.shape[1],
obs2prior_dict={'a_user': True, 'b_item': True, 'c_constant': False, 'd_user': True, 'e_item': True,
'f1_constant': False, 'f2_constant': False, 'f3_constant': False, 'f4_constant': False, 'f5_constant': False, 'f6_constant': False, 'f7_constant': False},
coef_dim_dict={'a_user': 1, 'b_item': 1, 'c_constant': 1, 'd_user': 10, 'e_item': 10,
'f1_constant': batch.user_obs.shape[-1], 'f2_constant': batch.item_obs.shape[-1], 'f3_constant': batch.session_obs.shape[-1],
'f4_constant': batch.useritem_obs.shape[-1] , 'f5_constant': batch.usersession_obs.shape[-1], 'f6_constant': batch.itemsession_obs.shape[-1], 'f7_constant': batch.usersessionitem_obs.shape[-1]}
)
P = bemb.predict_proba(batch)

output = bemb.forward(batch,
return_type=return_type, return_scope=return_scope,
deterministic=deterministic,
sample_dict=None,
num_seeds=num_seeds)

if (return_scope == 'item_index') and (deterministic == True):
self.assertEqual(output.shape, (len(batch),))
elif (return_scope == 'all_items') and (deterministic == True):
self.assertEqual(output.shape, (len(batch), num_items))
elif (return_scope == 'item_index') and (deterministic == False):
self.assertEqual(output.shape, (num_seeds, len(batch)))
elif (return_scope == 'all_items') and (deterministic == False):
self.assertEqual(output.shape, (num_seeds, len(batch), num_items))

# test predict_proba method.
P = bemb.predict_proba(batch)
if pred_item:
self.assertEqual(P.shape, (len(batch), num_items))
else:
