Skip to content

Commit

Permalink
gamma coeff implemented; unclean right now
Browse files Browse the repository at this point in the history
  • Loading branch information
kanodiaayush committed Dec 6, 2023
1 parent 51ba2f1 commit 4b47bc2
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 77 deletions.
87 changes: 24 additions & 63 deletions bemb/model/bayesian_coefficient.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,14 @@ def __init__(self,
assert distribution in ['gaussian', 'gamma'], f'Unsupported distribution {distribution}'
if distribution == 'gamma':
assert not obs2prior, 'Gamma distribution is not supported for obs2prior at present.'
mean = 1
variance = 10
mean = 1.0
variance = 10.0
assert mean > 0, 'Gamma distribution requires mean > 0'
assert variance > 0, 'Gamma distribution requires variance > 0'
prior_mean = mean**2 / variance
prior_variance = mean / variance
shape = mean ** 2 / variance
rate = mean / variance
prior_mean = shape
prior_variance = rate

self.distribution = distribution

Expand Down Expand Up @@ -132,21 +134,21 @@ def __init__(self,
num_classes, dim) * self.prior_variance)

# create variational distribution.
self.variational_mean_flexible = nn.Parameter(
torch.randn(num_classes, dim), requires_grad=True)

if self.distribution == 'gaussian':
self.variational_mean_flexible = nn.Parameter(
torch.randn(num_classes, dim), requires_grad=True)
# TOOD(kanodiaayush): initialize the gamma distribution variational mean in a more principled way.
'''
if self.distribution == 'gamma':
# take absolute value of the variational mean.
self.variational_mean_flexible.data = torch.abs(
self.variational_mean_flexible.data)
'''
elif self.distribution == 'gamma':
# initialize using uniform distribution between 0.5 and 1.5
# for a gamma distribution, we store the concentration as log(concentration) = variational_mean_flexible
self.variational_mean_flexible = nn.Parameter(
torch.rand(num_classes, dim) + 0.5, requires_grad=True)

if self.is_H and self.H_zero_mask is not None:
assert self.H_zero_mask.shape == self.variational_mean_flexible.shape, \
f"The H_zero_mask should have exactly the shape as the H variable, `H_zero_mask`.shape is {self.H_zero_mask.shape}, `H`.shape is {self.variational_mean_flexible.shape} "

# for gamma distribution, we store the rate as log(rate) = variational_logstd
self.variational_logstd = nn.Parameter(
torch.randn(num_classes, dim), requires_grad=True)

Expand Down Expand Up @@ -190,7 +192,8 @@ def variational_mean(self) -> torch.Tensor:
M = self.variational_mean_fixed + self.variational_mean_flexible

if self.distribution == 'gamma':
M = torch.pow(M, 2) + 0.000001
# M = torch.pow(M, 2) + 0.000001
M = M.exp() / self.variational_logstd.exp()

if self.is_H and (self.H_zero_mask is not None):
# a H-variable with zero-entry restriction.
Expand Down Expand Up @@ -246,44 +249,17 @@ def log_prior(self,
mu = self.prior_zero_mean

if self.distribution == 'gaussian':
# DEBUG_MARKER
'''
print('sample.shape', sample.shape)
print('gaussian')
print("mu.shape, self.prior_cov_diag.shape")
print(mu.shape, self.prior_cov_diag.shape)
'''
out = LowRankMultivariateNormal(loc=mu,
cov_factor=self.prior_cov_factor,
cov_diag=self.prior_cov_diag).log_prob(sample)
elif self.distribution == 'gamma':
concentration = torch.pow(mu, 2)/self.prior_cov_diag
rate = mu/self.prior_cov_diag
# DEBUG_MARKER
'''
print('sample.shape', sample.shape)
print('gamma')
print("mu.shape, self.prior_cov_diag.shape")
print(mu.shape, self.prior_cov_diag.shape)
print("concentration.shape, rate.shape")
print(concentration.shape, rate.shape)
'''
concentration = mu
rate = self.prior_variance
out = Gamma(concentration=concentration,
rate=rate).log_prob(sample)
# sum over the last dimension
out = torch.sum(out, dim=-1)


# DEBUG_MARKER
'''
print("sample.shape")
print(sample.shape)
print("out.shape")
print(out.shape)
print("num_seeds, num_classes")
print(num_seeds, num_classes)
breakpoint()
'''
assert out.shape == (num_seeds, num_classes)
return out

Expand Down Expand Up @@ -321,14 +297,6 @@ def rsample(self, num_seeds: int = 1) -> Union[torch.Tensor, Tuple[torch.Tensor]
value_sample = self.variational_distribution.rsample(
torch.Size([num_seeds]))
# DEBUG_MARKER
'''
print("rsample")
print(self.distribution)
print("value_sample.shape")
print(value_sample.shape)
breakpoint()
'''
# DEBUG_MARKER
if self.obs2prior:
# sample obs2prior H as well.
H_sample = self.prior_H.rsample(num_seeds=num_seeds)
Expand All @@ -345,18 +313,11 @@ def variational_distribution(self) -> Union[LowRankMultivariateNormal, Gamma]:
cov_factor=self.variational_cov_factor,
cov_diag=torch.exp(self.variational_logstd))
elif self.distribution == 'gamma':
# concentration is mean**2 / var (std**2)
concentration = torch.pow(self.variational_mean, 2)/torch.pow(torch.exp(self.variational_logstd), 2)
# rate is mean / var (std**2)
rate = self.variational_mean/torch.pow(torch.exp(self.variational_logstd), 2)
# DEBUG_MARKER
'''
print("self.variational_mean, self.variational_logstd")
print(self.variational_mean, self.variational_logstd)
print("concentration, rate")
print(concentration, rate)
'''
# DEBUG_MARKER
# for a gamma distribution, we store the concentration as log(concentration) = variational_mean_flexible
assert self.variational_mean_fixed == None, 'Gamma distribution does not support fixed mean'
concentration = self.variational_mean_flexible.exp()
# for gamma distribution, we store the rate as log(rate) = variational_logstd
rate = torch.exp(self.variational_logstd)
return Gamma(concentration=concentration, rate=rate)
else:
raise NotImplementedError("Unknown variational distribution type.")
Expand Down
33 changes: 22 additions & 11 deletions bemb/model/bemb.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,26 @@ def parse_utility(utility_string: str) -> List[Dict[str, Union[List[str], None]]
A helper function parse utility string into a list of additive terms.
Example:
utility_string = 'lambda_item + theta_user * alpha_item + gamma_user * beta_item * price_obs'
utility_string = 'lambda_item + theta_user * alpha_item - gamma_user * beta_item * price_obs'
output = [
{
'coefficient': ['lambda_item'],
'observable': None
'observable': None,
'sign': 1.0,
},
{
'coefficient': ['theta_user', 'alpha_item'],
'observable': None
'sign': 1.0,
},
{
'coefficient': ['gamma_user', 'beta_item'],
'observable': 'price_obs'
'sign': -1.0,
}
]
Note that 'minus' is allowed in the utility string. If the first term is negative, the minus should be without a space.
"""
# split additive terms
coefficient_suffix = ('_item', '_user', '_constant', '_category')
Expand All @@ -76,10 +81,16 @@ def is_coefficient(name: str) -> bool:
def is_observable(name: str) -> bool:
return any(name.startswith(prefix) for prefix in observable_prefix)

utility_string = utility_string.replace(' - ', ' + -')
additive_terms = utility_string.split(' + ')
additive_decomposition = list()
for term in additive_terms:
atom = {'coefficient': [], 'observable': None}
if term.startswith('-'):
sign = -1.0
term = term[1:]
else:
sign = 1.0
atom = {'coefficient': [], 'observable': None, 'sign': sign}
# split multiplicative terms.
for x in term.split(' * '):
assert not (is_observable(x) and is_coefficient(x)), f"The element {x} is ambiguous, it follows naming convention of both an observable and a coefficient."
Expand Down Expand Up @@ -927,6 +938,7 @@ def reshape_observable(obs, name):
sample_dict[coef_name], coef_name)
assert coef_sample.shape == (R, P, I, 1)
additive_term = coef_sample.view(R, P, I)
additive_term *= term['sign']

# Type II: factorized coefficient, e.g., <theta_user, lambda_item>.
elif len(term['coefficient']) == 2 and term['observable'] is None:
Expand All @@ -942,6 +954,7 @@ def reshape_observable(obs, name):
R, P, I, positive_integer)

additive_term = (coef_sample_0 * coef_sample_1).sum(dim=-1)
additive_term *= term['sign']

# Type III: single coefficient multiplied by observable, e.g., theta_user * x_obs_item.
elif len(term['coefficient']) == 1 and term['observable'] is not None:
Expand All @@ -955,8 +968,7 @@ def reshape_observable(obs, name):
assert obs.shape == (R, P, I, positive_integer)

additive_term = (coef_sample * obs).sum(dim=-1)
if obs_name == 'price_obs':
additive_term *= -1.0
additive_term *= term['sign']

# Type IV: factorized coefficient multiplied by observable.
# e.g., gamma_user * beta_item * price_obs.
Expand Down Expand Up @@ -987,8 +999,7 @@ def reshape_observable(obs, name):
coef = (coef_sample_0 * coef_sample_1).sum(dim=-1)

additive_term = (coef * obs).sum(dim=-1)
if obs_name == 'price_obs':
additive_term *= -1.0
additive_term *= term['sign']

else:
raise ValueError(f'Undefined term type: {term}')
Expand Down Expand Up @@ -1162,6 +1173,7 @@ def reshape_observable(obs, name):
sample_dict[coef_name], coef_name)
assert coef_sample.shape == (R, total_computation, 1)
additive_term = coef_sample.view(R, total_computation)
additive_term *= term['sign']

# Type II: factorized coefficient, e.g., <theta_user, lambda_item>.
elif len(term['coefficient']) == 2 and term['observable'] is None:
Expand All @@ -1177,6 +1189,7 @@ def reshape_observable(obs, name):
R, total_computation, positive_integer)

additive_term = (coef_sample_0 * coef_sample_1).sum(dim=-1)
additive_term *= term['sign']

# Type III: single coefficient multiplied by observable, e.g., theta_user * x_obs_item.
elif len(term['coefficient']) == 1 and term['observable'] is not None:
Expand All @@ -1191,8 +1204,7 @@ def reshape_observable(obs, name):
assert obs.shape == (R, total_computation, positive_integer)

additive_term = (coef_sample * obs).sum(dim=-1)
if obs_name == 'price_obs':
additive_term *= -1.0
additive_term *= term['sign']

# Type IV: factorized coefficient multiplied by observable.
# e.g., gamma_user * beta_item * price_obs.
Expand Down Expand Up @@ -1222,8 +1234,7 @@ def reshape_observable(obs, name):
coef = (coef_sample_0 * coef_sample_1).sum(dim=-1)

additive_term = (coef * obs).sum(dim=-1)
if obs_name == 'price_obs':
additive_term *= -1.0
additive_term *= term['sign']

else:
raise ValueError(f'Undefined term type: {term}')
Expand Down
11 changes: 8 additions & 3 deletions tutorials/supermarket/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,16 @@ def load_tsv(file_name, data_dir):
bemb = run(bemb, dataset_list, batch_size=configs.batch_size, num_epochs=configs.num_epochs, run_test=False)

# '''
coeffs = bemb.model.coef_dict['gamma_user'].variational_mean.detach().cpu().numpy()
# coeffs = coeffs**2
# give distribution statistics
print('Coefficients statistics:')
print(pd.DataFrame(coeffs).describe())
if 'gamma_user' in configs.utility:
coeffs_gamma = bemb.model.coef_dict['gamma_user'].variational_mean.detach().cpu().numpy()
print('Coefficients statistics Gamma:')
print(pd.DataFrame(coeffs_gamma).describe())
if 'nfact_category' in configs.utility:
coeffs_nfact = bemb.model.coef_dict['nfact_category'].variational_mean.detach().cpu().numpy()
print('Coefficients statistics nfact_category:')
print(pd.DataFrame(coeffs_nfact).describe())
# '''

# ==============================================================================================
Expand Down

0 comments on commit 4b47bc2

Please sign in to comment.