diff --git a/bemb/model/bemb_chunked.py b/bemb/model/bemb_chunked.py
index 06b9396..c2f563c 100644
--- a/bemb/model/bemb_chunked.py
+++ b/bemb/model/bemb_chunked.py
@@ -499,7 +499,8 @@
     def forward(self, batch: ChoiceDataset,
                 return_scope: str,
                 deterministic: bool = True,
                 sample_dict: Optional[Dict[str, torch.Tensor]] = None,
-                num_seeds: Optional[int] = None
+                num_seeds: Optional[int] = None,
+                debug=False,
                 ) -> torch.Tensor:
         """A combined method for inference with the model.
@@ -607,7 +608,7 @@
         elif return_scope == 'item_index':
             # (num_seeds, len(batch))
             out = self.log_likelihood_item_index(
-                batch=batch, sample_dict=sample_dict, return_logit=return_logit)
+                batch=batch, sample_dict=sample_dict, return_logit=return_logit, debug=debug)
 
             if deterministic:
                 # drop the first dimension, which has size of `num_seeds` (equals 1 in the deterministic case).
@@ -817,7 +818,7 @@
         batch.session_index = batch.session_index.repeat_interleave(self.num_items)
         return self.log_likelihood_item_index(batch, return_logit, sample_dict, all_items=True)
 
-    def log_likelihood_item_index(self, batch: ChoiceDataset, return_logit: bool, sample_dict: Dict[str, torch.Tensor], all_items: bool=False) -> torch.Tensor:
+    def log_likelihood_item_index(self, batch: ChoiceDataset, return_logit: bool, sample_dict: Dict[str, torch.Tensor], all_items: bool=False, debug=False) -> torch.Tensor:
         """
         NOTE for developers:
         This method is more efficient and only computes log-likelihood/logit(utility) for item in item_index[i] for each
@@ -978,6 +979,8 @@ def reshape_observable(obs, name):
 
         # loop over additive term to utility
         for term in self.formula:
+            if debug:
+                breakpoint()
             # Type I: single coefficient, e.g., lambda_item or lambda_user.
             if len(term['coefficient']) == 1 and term['observable'] is None:
                 # E.g., lambda_item or lambda_user
@@ -1087,6 +1090,7 @@ def reshape_observable(obs, name):
         # compute log likelihood log p(choosing item i | user, item latents)
         # compute the log probability from logits/utilities.
         # output shape: (num_seeds, len(batch), num_items)
+        # breakpoint()
         log_p = scatter_log_softmax(utility, reverse_indices, dim=-1)
         # select the log-P of the item actually bought.
         log_p = log_p[:, item_index_expanded == relevant_item_index]
@@ -1234,6 +1238,8 @@ def elbo(self, batch: ChoiceDataset, num_seeds: int = 1) -> torch.Tensor:
         # 1. sample latent variables from their variational distributions.
         # ==============================================================================================================
         if self.deterministic_variational:
+            sample_dict = self.sample_coefficient_dictionary(num_seeds)
+            '''
             num_seeds = 1
             # Use the means of variational distributions as the sole deterministic MC sample.
             # NOTE: here we don't need to sample the obs2prior weight H since we only compute the log-likelihood.
@@ -1242,6 +1248,7 @@ def elbo(self, batch: ChoiceDataset, num_seeds: int = 1) -> torch.Tensor:
             for coef_name, coef in self.coef_dict.items():
                 sample_dict[coef_name] = coef.variational_distribution.mean.unsqueeze(
                     dim=0)  # (1, num_*, dim)
+            '''
         else:
             sample_dict = self.sample_coefficient_dictionary(num_seeds)
 
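
Usage note (not part of the patch): a minimal sketch of how the new debug flag could be exercised, assuming `model` is an already-constructed instance of the chunked BEMB model from this file and `batch` is a ChoiceDataset; both names are illustrative, not defined by the diff.

    # Sketch only: with debug=True, forward() threads the flag into
    # log_likelihood_item_index(), which drops into breakpoint() at the start
    # of each additive utility term in self.formula for interactive inspection.
    log_p = model(batch, return_scope='item_index', deterministic=True, debug=True)
    # Per the diff's comment, deterministic=True drops the leading num_seeds
    # dimension (size 1), leaving one log-likelihood per batch element.
    print(log_p.shape)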