diff --git a/bemb/model/bemb_chunked.py b/bemb/model/bemb_chunked.py index c2f563c..64eb511 100644 --- a/bemb/model/bemb_chunked.py +++ b/bemb/model/bemb_chunked.py @@ -818,7 +818,7 @@ def log_likelihood_all_items(self, batch: ChoiceDataset, return_logit: bool, sam batch.session_index = batch.session_index.repeat_interleave(self.num_items) return self.log_likelihood_item_index(batch, return_logit, sample_dict, all_items=True) - def log_likelihood_item_index(self, batch: ChoiceDataset, return_logit: bool, sample_dict: Dict[str, torch.Tensor], all_items: bool=False, debug=False) -> torch.Tensor: + def log_likelihood_item_index(self, batch: ChoiceDataset, return_logit: bool, sample_dict: Dict[str, torch.Tensor], all_items: bool=False, debug=False, return_logit_components=False, return_price_coeff=False) -> torch.Tensor: """ NOTE for developers: This method is more efficient and only computes log-likelihood/logit(utility) for item in item_index[i] for each @@ -841,6 +841,10 @@ def log_likelihood_item_index(self, batch: ChoiceDataset, return_logit: bool, sa The value of sample_dict should be tensors of shape (num_seeds, num_classes, dim) where num_classes in {num_users, num_items, 1} and dim in {latent_dim(K), num_item_obs, num_user_obs, 1}. + all_items: return for all items + debug: debug param, keeps evolving, see code + return_logit_components: return a tensor of size (len(batch), num_terms) where num_terms is the number of additive terms in self.formula. If this is set to True, return_logit must be set to True too + Returns: torch.Tensor: a tensor of shape (num_seeds, len(batch)), where @@ -976,9 +980,20 @@ def reshape_observable(obs, name): # Compute Components related to users and items only. # ========================================================================================== utility = torch.zeros(R, total_computation, device=self.device) + if return_logit_components: + assert R == 1, 'return_logit_components is not supported for R > 1' + assert return_logit, "return_logit_components requires return_logit" + utility_components = torch.zeros(len(self.formula), R, total_computation, device=self.device) + + if return_price_coeff: + # 'price_obs' needs to be seen in self.formula exactly once in self.utility_formula + assert R == 1, 'return_price_coeff is not supported for R > 1' + assert self.utility_formula.count('price_obs') == 1, "price_obs needs to be seen in self.formula exactly once for return_price_coeff" + price_coeffs = torch.zeros(R, total_computation, device=self.device) # loop over additive term to utility - for term in self.formula: + for ii, term in enumerate(self.formula): + obs_coeff = None if debug: breakpoint() # Type I: single coefficient, e.g., lambda_item or lambda_user. @@ -1018,6 +1033,10 @@ def reshape_observable(obs, name): obs = reshape_observable(getattr(batch, obs_name), obs_name) assert obs.shape == (R, total_computation, positive_integer) + if return_price_coeff and term['observable'] is not None and term['observable'].startswith('price_'): + obs_coeff = coef_sample.sum(dim=-1) + price_coeffs = obs_coeff + additive_term = (coef_sample * obs).sum(dim=-1) # Type IV: factorized coefficient multiplied by observable. @@ -1047,6 +1066,10 @@ def reshape_observable(obs, name): # compute the factorized coefficient with shape (R, P, I, O). coef = (coef_sample_0 * coef_sample_1).sum(dim=-1) + if return_price_coeff and term['observable'] is not None and term['observable'].startswith('price_'): + obs_coeff = coef.sum(dim=-1) + price_coeffs = obs_coeff + additive_term = (coef * obs).sum(dim=-1) else: @@ -1054,6 +1077,8 @@ def reshape_observable(obs, name): assert additive_term.shape == (R, total_computation) utility += additive_term + if return_logit_components: + utility_components[ii] = additive_term # ========================================================================================== # Mask Out Unavailable Items in Each Session. @@ -1064,9 +1089,14 @@ def reshape_observable(obs, name): A = batch.item_availability[session_index, relevant_item_index].unsqueeze( dim=0).expand(R, -1) utility[~A] = - (torch.finfo(utility.dtype).max / 2) + if return_logit_components: + utility_components[:, ~A] = - (torch.finfo(utility.dtype).max / 2) + if return_price_coeff: + price_coeffs[~A] = 0 for module in self.additional_modules: # current utility shape: (R, total_computation) + assert False, "additional modules not supported for bemb_chunked" additive_term = module(batch) assert additive_term.shape == ( R, len(batch)) or additive_term.shape == (R, len(batch), 1) @@ -1082,9 +1112,19 @@ def reshape_observable(obs, name): if return_logit: # (num_seeds, len(batch)) - u = utility[:, item_index_expanded == relevant_item_index] - assert u.shape == (R, len(batch)) - return u + if return_logit_components: + u = utility_components[:, :, item_index_expanded == relevant_item_index] + assert u.shape == (len(self.formula), R, len(batch)) + return u + else: + u = utility[:, item_index_expanded == relevant_item_index] + assert u.shape == (R, len(batch)) + if return_price_coeff: + price_coeffs = price_coeffs[ :, item_index_expanded == relevant_item_index].squeeze(dim=0) + assert price_coeffs.shape[0] == len(batch) + return u, price_coeffs + else: + return u if self.pred_item: # compute log likelihood log p(choosing item i | user, item latents)