
Commit

debug
kanodiaayush committed Aug 1, 2023
1 parent a50a245 commit 0431ad4
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions bemb/model/bemb_chunked.py
@@ -499,7 +499,8 @@ def forward(self, batch: ChoiceDataset,
                 return_scope: str,
                 deterministic: bool = True,
                 sample_dict: Optional[Dict[str, torch.Tensor]] = None,
-                num_seeds: Optional[int] = None
+                num_seeds: Optional[int] = None,
+                debug=False,
                 ) -> torch.Tensor:
         """A combined method for inference with the model.
@@ -607,7 +608,7 @@ def forward(self, batch: ChoiceDataset,
         elif return_scope == 'item_index':
             # (num_seeds, len(batch))
             out = self.log_likelihood_item_index(
-                batch=batch, sample_dict=sample_dict, return_logit=return_logit)
+                batch=batch, sample_dict=sample_dict, return_logit=return_logit, debug=debug)
 
         if deterministic:
             # drop the first dimension, which has size of `num_seeds` (equals 1 in the deterministic case).
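The two hunks above only add the flag to forward()'s signature and thread it into log_likelihood_item_index; the hunk at line 979 below is where it actually takes effect. The following is a minimal, self-contained sketch of that plumbing, using a toy class rather than the real BEMB model (method bodies and the batch structure here are illustrative only):

class ToyBEMB:
    def forward(self, batch, return_scope: str, deterministic: bool = True, debug: bool = False):
        # Mirrors the commit: forward() grows a debug keyword and passes it straight through.
        if return_scope == 'item_index':
            return self.log_likelihood_item_index(batch, return_logit=False, debug=debug)
        raise ValueError(f'unsupported return_scope: {return_scope}')

    def log_likelihood_item_index(self, batch, return_logit: bool, debug: bool = False):
        # Stand-in for the loop over self.formula: with debug=True each term drops into pdb.
        for term in batch:
            if debug:
                breakpoint()
        return 0.0

# With the default debug=False the loop runs without stopping.
ToyBEMB().forward(batch=[{'coefficient': ['lambda_item'], 'observable': None}],
                  return_scope='item_index')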
@@ -817,7 +818,7 @@ def log_likelihood_all_items(self, batch: ChoiceDataset, return_logit: bool, sam
         batch.session_index = batch.session_index.repeat_interleave(self.num_items)
         return self.log_likelihood_item_index(batch, return_logit, sample_dict, all_items=True)
 
-    def log_likelihood_item_index(self, batch: ChoiceDataset, return_logit: bool, sample_dict: Dict[str, torch.Tensor], all_items: bool=False) -> torch.Tensor:
+    def log_likelihood_item_index(self, batch: ChoiceDataset, return_logit: bool, sample_dict: Dict[str, torch.Tensor], all_items: bool=False, debug=False) -> torch.Tensor:
         """
         NOTE for developers:
         This method is more efficient and only computes log-likelihood/logit(utility) for item in item_index[i] for each
@@ -978,6 +979,8 @@ def reshape_observable(obs, name):
 
         # loop over additive term to utility
         for term in self.formula:
+            if debug:
+                breakpoint()
             # Type I: single coefficient, e.g., lambda_item or lambda_user.
             if len(term['coefficient']) == 1 and term['observable'] is None:
                 # E.g., lambda_item or lambda_user
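A side note on the breakpoint() call added above (standard Python behavior, not part of the commit): since PEP 553 (Python 3.7) the builtin dispatches to sys.breakpointhook(), which defaults to pdb.set_trace() and checks the PYTHONBREAKPOINT environment variable on every call, so a stray debug=True can be neutralized without editing code:

import os

# Setting PYTHONBREAKPOINT to "0" turns every breakpoint() call into a no-op,
# which is handy for CI or long training runs where an interactive stop would hang.
os.environ['PYTHONBREAKPOINT'] = '0'

for term in ({'coefficient': ['lambda_item'], 'observable': None},):
    breakpoint()  # skipped because of the environment variable set above
print('loop finished without an interactive stop')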
@@ -1087,6 +1090,7 @@ def reshape_observable(obs, name):
         # compute log likelihood log p(choosing item i | user, item latents)
         # compute the log probability from logits/utilities.
         # output shape: (num_seeds, len(batch), num_items)
+        # breakpoint()
         log_p = scatter_log_softmax(utility, reverse_indices, dim=-1)
         # select the log-P of the item actually bought.
         log_p = log_p[:, item_index_expanded == relevant_item_index]
@@ -1234,6 +1238,8 @@ def elbo(self, batch: ChoiceDataset, num_seeds: int = 1) -> torch.Tensor:
         # 1. sample latent variables from their variational distributions.
         # ==============================================================================================================
         if self.deterministic_variational:
+            sample_dict = self.sample_coefficient_dictionary(num_seeds)
+            '''
             num_seeds = 1
             # Use the means of variational distributions as the sole deterministic MC sample.
             # NOTE: here we don't need to sample the obs2prior weight H since we only compute the log-likelihood.
@@ -1242,6 +1248,7 @@ def elbo(self, batch: ChoiceDataset, num_seeds: int = 1) -> torch.Tensor:
             for coef_name, coef in self.coef_dict.items():
                 sample_dict[coef_name] = coef.variational_distribution.mean.unsqueeze(
                     dim=0)  # (1, num_*, dim)
+            '''
         else:
             sample_dict = self.sample_coefficient_dictionary(num_seeds)
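The two elbo hunks above stop special-casing the deterministic branch: both branches now call self.sample_coefficient_dictionary(num_seeds), and the old mean-based code is fenced off inside a triple-quoted string rather than deleted. Below is a rough sketch of the shape convention the disabled block documents, with illustrative tensors standing in for the actual BEMB variational distributions (assumed, not confirmed, to be Gaussian here):

import torch

num_items, latent_dim, num_seeds = 50, 10, 4
variational_mean = torch.zeros(num_items, latent_dim)       # (num_*, dim)

# Old deterministic path: a single pseudo-sample per coefficient, the variational mean.
deterministic_sample = variational_mean.unsqueeze(dim=0)    # (1, num_*, dim)

# New path: Monte Carlo draws (stand-in below), presumably shaped (num_seeds, num_*, dim)
# even when self.deterministic_variational is True.
mc_samples = torch.randn(num_seeds, num_items, latent_dim)
print(deterministic_sample.shape, mc_samples.shape)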


