Skip to content

Commit

Permalink
add logit_components, price_coeff
Browse files Browse the repository at this point in the history
  • Loading branch information
kanodiaayush committed Nov 30, 2023
1 parent 0431ad4 commit 38abd25
Showing 1 changed file with 45 additions and 5 deletions.
50 changes: 45 additions & 5 deletions bemb/model/bemb_chunked.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,7 @@ def log_likelihood_all_items(self, batch: ChoiceDataset, return_logit: bool, sam
batch.session_index = batch.session_index.repeat_interleave(self.num_items)
return self.log_likelihood_item_index(batch, return_logit, sample_dict, all_items=True)

def log_likelihood_item_index(self, batch: ChoiceDataset, return_logit: bool, sample_dict: Dict[str, torch.Tensor], all_items: bool=False, debug=False) -> torch.Tensor:
def log_likelihood_item_index(self, batch: ChoiceDataset, return_logit: bool, sample_dict: Dict[str, torch.Tensor], all_items: bool=False, debug=False, return_logit_components=False, return_price_coeff=False) -> torch.Tensor:
"""
NOTE for developers:
This method is more efficient and only computes log-likelihood/logit(utility) for item in item_index[i] for each
Expand All @@ -841,6 +841,10 @@ def log_likelihood_item_index(self, batch: ChoiceDataset, return_logit: bool, sa
The value of sample_dict should be tensors of shape (num_seeds, num_classes, dim)
where num_classes in {num_users, num_items, 1}
and dim in {latent_dim(K), num_item_obs, num_user_obs, 1}.
all_items: return for all items
debug: developer-only flag; when True, breakpoint() is called at the start of every additive-term iteration. Its behavior may change, see code.
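return_logit_components: if True, return a tensor of shape (num_terms, num_seeds, len(batch)) holding each additive term of self.formula separately, where num_terms = len(self.formula). Requires return_logit to be True and num_seeds == 1.
return_price_coeff: if True (with return_logit True and return_logit_components False), return a tuple (logit, price_coeffs) where price_coeffs has len(batch) entries. Requires num_seeds == 1 and 'price_obs' to appear exactly once in self.utility_formula.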
Returns:
torch.Tensor: a tensor of shape (num_seeds, len(batch)), where
Expand Down Expand Up @@ -976,9 +980,20 @@ def reshape_observable(obs, name):
# Compute Components related to users and items only.
# ==========================================================================================
utility = torch.zeros(R, total_computation, device=self.device)
if return_logit_components:
assert R == 1, 'return_logit_components is not supported for R > 1'
assert return_logit, "return_logit_components requires return_logit"
utility_components = torch.zeros(len(self.formula), R, total_computation, device=self.device)

if return_price_coeff:
# 'price_obs' must appear exactly once in self.utility_formula
assert R == 1, 'return_price_coeff is not supported for R > 1'
assert self.utility_formula.count('price_obs') == 1, "price_obs needs to be seen in self.formula exactly once for return_price_coeff"
price_coeffs = torch.zeros(R, total_computation, device=self.device)

# loop over additive term to utility
for term in self.formula:
for ii, term in enumerate(self.formula):
obs_coeff = None
if debug:
breakpoint()
# Type I: single coefficient, e.g., lambda_item or lambda_user.
Expand Down Expand Up @@ -1018,6 +1033,10 @@ def reshape_observable(obs, name):
obs = reshape_observable(getattr(batch, obs_name), obs_name)
assert obs.shape == (R, total_computation, positive_integer)

if return_price_coeff and term['observable'] is not None and term['observable'].startswith('price_'):
obs_coeff = coef_sample.sum(dim=-1)
price_coeffs = obs_coeff

additive_term = (coef_sample * obs).sum(dim=-1)

# Type IV: factorized coefficient multiplied by observable.
Expand Down Expand Up @@ -1047,13 +1066,19 @@ def reshape_observable(obs, name):
# compute the factorized coefficient with shape (R, P, I, O).
coef = (coef_sample_0 * coef_sample_1).sum(dim=-1)

if return_price_coeff and term['observable'] is not None and term['observable'].startswith('price_'):
obs_coeff = coef.sum(dim=-1)
price_coeffs = obs_coeff

additive_term = (coef * obs).sum(dim=-1)

else:
raise ValueError(f'Undefined term type: {term}')

assert additive_term.shape == (R, total_computation)
utility += additive_term
if return_logit_components:
utility_components[ii] = additive_term

# ==========================================================================================
# Mask Out Unavailable Items in Each Session.
Expand All @@ -1064,9 +1089,14 @@ def reshape_observable(obs, name):
A = batch.item_availability[session_index, relevant_item_index].unsqueeze(
dim=0).expand(R, -1)
utility[~A] = - (torch.finfo(utility.dtype).max / 2)
if return_logit_components:
utility_components[:, ~A] = - (torch.finfo(utility.dtype).max / 2)
if return_price_coeff:
price_coeffs[~A] = 0

for module in self.additional_modules:
# current utility shape: (R, total_computation)
assert False, "additional modules not supported for bemb_chunked"
additive_term = module(batch)
assert additive_term.shape == (
R, len(batch)) or additive_term.shape == (R, len(batch), 1)
Expand All @@ -1082,9 +1112,19 @@ def reshape_observable(obs, name):

if return_logit:
# (num_seeds, len(batch))
u = utility[:, item_index_expanded == relevant_item_index]
assert u.shape == (R, len(batch))
return u
if return_logit_components:
u = utility_components[:, :, item_index_expanded == relevant_item_index]
assert u.shape == (len(self.formula), R, len(batch))
return u
else:
u = utility[:, item_index_expanded == relevant_item_index]
assert u.shape == (R, len(batch))
if return_price_coeff:
price_coeffs = price_coeffs[ :, item_index_expanded == relevant_item_index].squeeze(dim=0)
assert price_coeffs.shape[0] == len(batch)
return u, price_coeffs
else:
return u

if self.pred_item:
# compute log likelihood log p(choosing item i | user, item latents)
Expand Down

0 comments on commit 38abd25

Please sign in to comment.