Update prior handling with positive variationals
1. The hyperprior is Gaussian.
2. The prior is Gaussian.
3. The lower bound of the lognormal mean clamp is lowered (from -2.0 to -10.0).
4. Prior-range handling is clearer: the range assertions now run only when obs2prior is off.
kanodiaayush committed Jun 17, 2024
1 parent b07d90a commit 0e034de
Showing 2 changed files with 38 additions and 35 deletions.
54 changes: 28 additions & 26 deletions bemb/model/bayesian_coefficient.py
@@ -89,26 +89,28 @@ def __init__(self,
self.mean_clamp = (np.log(0.1), np.log(100.0))
self.logstd_clamp = (np.log(0.1), np.log(10000.0))
shape = prior_mean ** 2 / prior_variance
assert shape > np.exp(self.mean_clamp[0])**2 and shape < np.exp(self.mean_clamp[1])**2, f'Gamma shape {shape} is out of range, should be in ({np.exp(self.mean_clamp[0])**2}, {np.exp(self.mean_clamp[1])**2})'
rate = prior_mean / prior_variance
assert rate > np.exp(self.logstd_clamp[0]) and rate < np.exp(self.logstd_clamp[1]), f'Gamma rate {rate} is out of range, should be in ({np.exp(self.logstd_clamp[0])}, {np.exp(self.logstd_clamp[1])})'
# prior_mean stores ln(shape) for gamma
prior_mean = np.log(shape)
# prior_variance stores rate for gamma
prior_variance = rate
# prior_mean = np.log(prior_mean)
# prior_variance = prior_variance
if not obs2prior:
assert shape > np.exp(self.mean_clamp[0])**2 and shape < np.exp(self.mean_clamp[1])**2, f'Gamma shape {shape} is out of range, should be in ({np.exp(self.mean_clamp[0])**2}, {np.exp(self.mean_clamp[1])**2})'
assert rate > np.exp(self.logstd_clamp[0]) and rate < np.exp(self.logstd_clamp[1]), f'Gamma rate {rate} is out of range, should be in ({np.exp(self.logstd_clamp[0])}, {np.exp(self.logstd_clamp[1])})'
# prior_mean stores ln(shape) for gamma
prior_mean = np.log(shape)
# prior_variance stores rate for gamma
prior_variance = rate
# prior_mean = np.log(prior_mean)
# prior_variance = prior_variance

elif distribution == 'lognormal':
# mean is exp(mu + sigma^2/2), variance is (exp(sigma^2) - 1) * exp(2*mu + sigma^2)
# prior_mean in -2, exp(3)
self.mean_clamp = (-2.0, np.exp(1.5))
self.mean_clamp = (-10.0, np.exp(1.5))
# sigma sq clamp exp(-20), exp(1.5)
# therefore sigma in (exp(-10), exp(0.75))
# therefore log sigma in (-10, 0.75)
self.logstd_clamp = (-10.0, 0.75)
assert prior_mean > self.mean_clamp[0] and prior_mean < self.mean_clamp[1], f'Lognormal distribution requires prior_mean in {self.mean_clamp}, given {prior_mean}'
assert np.sqrt(prior_variance) > np.exp(self.logstd_clamp[0]) and np.sqrt(prior_variance) < np.exp(self.logstd_clamp[1]), f'Lognormal distribution requires prior_variance in {self.logstd_clamp}, given {prior_variance}'
if not obs2prior:
assert prior_mean > self.mean_clamp[0] and prior_mean < self.mean_clamp[1], f'Lognormal distribution requires prior_mean in {self.mean_clamp}, given {prior_mean}'
assert np.sqrt(prior_variance) > np.exp(self.logstd_clamp[0]) and np.sqrt(prior_variance) < np.exp(self.logstd_clamp[1]), f'Lognormal distribution requires prior_variance in {self.logstd_clamp}, given {prior_variance}'
# assert prior_mean > np.exp(-100.0) and prior_mean < np.exp(10.0), f'Lognormal distribution requires shape in (exp(-100), exp(10)), given {prior_mean}'
# assert prior_variance > np.exp(-100.0) and prior_variance < np.exp(2.0), f'Lognormal distribution requires rate in (exp(-100), exp(2)), given {prior_variance}'
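The comment above quotes the lognormal moments mean = exp(mu + sigma^2/2) and variance = (exp(sigma^2) - 1) * exp(2*mu + sigma^2); a quick Monte Carlo sanity check of those formulas, with hypothetical mu and sigma:

    import numpy as np

    mu, sigma = 0.5, 0.3  # hypothetical location and scale
    mean = np.exp(mu + sigma ** 2 / 2)
    var = (np.exp(sigma ** 2) - 1.0) * np.exp(2 * mu + sigma ** 2)

    x = np.random.default_rng(0).lognormal(mu, sigma, size=1_000_000)
    assert np.isclose(x.mean(), mean, rtol=1e-2)
    assert np.isclose(x.var(), var, rtol=1e-2)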

@@ -254,6 +256,7 @@ def variational_mean(self) -> torch.Tensor:

elif self.distribution == 'lognormal':
M = (torch.minimum((M.exp() + 0.000001), torch.tensor(1e3)))
# M = torch.minimum(M + 0.000001, torch.tensor(1e3))

if self.is_H and (self.H_zero_mask is not None):
# a H-variable with zero-entry restriction.
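The new comment records an alternative transform that was not adopted; the live line keeps mapping the unconstrained mean through exp, then floors it slightly above zero and caps it at 1e3. A standalone sketch of that transform, with hypothetical inputs:

    import torch

    M = torch.tensor([-20.0, 0.0, 10.0])  # hypothetical unconstrained means
    M_pos = torch.minimum(M.exp() + 0.000001, torch.tensor(1e3))
    # exp(-20) ~ 2e-9 -> floored near 1e-6; exp(0) -> ~1.0; exp(10) ~ 22026 -> capped at 1e3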
@@ -314,24 +317,23 @@ def log_prior(self,
cov_diag=self.prior_cov_diag).log_prob(sample)
elif self.distribution == 'lognormal':
mu = torch.clamp(mu, min=-100.0, max=10.0)
out = LogNormal(loc=mu, scale=np.sqrt(self.prior_variance)).log_prob(sample)
out = torch.sum(out, dim=-1)
out = LowRankMultivariateNormal(loc=mu,
cov_factor=self.prior_cov_factor,
cov_diag=self.prior_cov_diag).log_prob(sample)
# out = LogNormal(loc=mu, scale=np.sqrt(self.prior_variance)).log_prob(sample)
# out = torch.sum(out, dim=-1)
# out = torch.zeros((num_seeds, num_classes), device=sample.device)

elif self.distribution == 'gamma':
mu = torch.clamp(mu, min=-100.0, max=4.0)
concentration = torch.exp(mu)
rate = self.prior_variance
'''
print(concentration)
print('concentration')
print(rate)
print('rate')
'''
out = Gamma(concentration=concentration,
rate=rate).log_prob(sample)
# sum over the last dimension
out = torch.sum(out, dim=-1)
out = LowRankMultivariateNormal(loc=mu,
cov_factor=self.prior_cov_factor,
cov_diag=self.prior_cov_diag).log_prob(sample)
# concentration = torch.exp(mu)
# rate = self.prior_variance
# out = Gamma(concentration=concentration,
# rate=rate).log_prob(sample)
# out = torch.sum(out, dim=-1)

# out = torch.zeros((num_seeds, num_classes), device=sample.device)
assert out.shape == (num_seeds, num_classes)
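Both the lognormal and gamma branches now score the sample under a Gaussian prior on the clamped location mu (points 1-2 of the commit message), keeping the old LogNormal/Gamma scoring as comments. A self-contained sketch of the shape contract, with hypothetical dimensions and a zero low-rank factor so the covariance is purely diagonal:

    import torch
    from torch.distributions import LowRankMultivariateNormal

    num_seeds, num_classes, dim = 2, 5, 3    # hypothetical shapes
    mu = torch.zeros(num_classes, dim)       # stands in for the clamped prior location
    cov_factor = torch.zeros(dim, 1)         # zero factor: no off-diagonal covariance
    cov_diag = torch.full((dim,), 100.0)     # diagonal prior variance
    sample = torch.randn(num_seeds, num_classes, dim)

    out = LowRankMultivariateNormal(loc=mu, cov_factor=cov_factor,
                                    cov_diag=cov_diag).log_prob(sample)
    assert out.shape == (num_seeds, num_classes)  # same contract as the assert above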
@@ -443,7 +445,7 @@ def clamp_params(self) -> None:
self.variational_mean_flexible.data = torch.clamp(
self.variational_mean_flexible.data, min=self.mean_clamp[0], max=self.mean_clamp[1])
self.variational_logstd.data = torch.clamp(
self.variational_logstd.data, min=self.mean_clamp[0], max=self.logstd_clamp[1])
self.variational_logstd.data, min=self.logstd_clamp[0], max=self.logstd_clamp[1])
# elif self.distribution == 'gamma':
# self.variational_mean_flexible.data = torch.clamp(
# self.variational_mean_flexible.data, min=-2.0, max=4)
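The one-line fix above makes both bounds of the log-std clamp come from logstd_clamp; previously the lower bound reused mean_clamp[0]. A minimal sketch of the corrected behavior, using the lognormal bounds from earlier in the diff:

    import torch

    logstd_clamp = (-10.0, 0.75)  # lognormal log-std bounds
    variational_logstd = torch.tensor([-15.0, 0.0, 2.0])
    clamped = torch.clamp(variational_logstd, min=logstd_clamp[0], max=logstd_clamp[1])
    # -> tensor([-10.0000, 0.0000, 0.7500])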
19 changes: 10 additions & 9 deletions tutorials/supermarket/configs3_1.yaml
@@ -3,19 +3,19 @@ device: cuda
# data_dir: /home/tianyudu/Data/MoreSupermarket/20180101-20191231_13/tsv/
# data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/m_9999/20180101-20191231_42/tsv
# data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/m_9999/2
data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/m_1904-1887-1974-2012-1992/20180101-20191231_44/tsv/
# data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/m_1904-1887-1974-2012-1992/20180101-20191231_44/tsv/
# data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/m_1887/chunked_1
# data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/sims34/1
data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/sims34/1
# utility: lambda_item
# utility: lambda_item + theta_user * alpha_item
# utility: lambda_item + theta_user * alpha_item + zeta_user * item_obs
# utility: lambda_item + theta_user * alpha_item + gamma_user * beta_item * price_obs
# utility: lambda_item + theta_user * alpha_item + gamma_user * price_obs
# utility: lambda_item + theta_user * alpha_item - nfact_category * gamma_user * price_obs
utility: theta_user * alpha_item - nfact_category * gamma_user * price_obs
# utility: theta_user * alpha_item + nfact_category * gamma_user * price_obs
# utility: -nfact_category * gamma_user * price_obs
# utility: -nfact_category * gamma_user * price_obs
# utility: -gamma_user * price_obs
utility: -gamma_user * price_obs
# utility: lambda_item
out_dir: ./output/
# model configuration.
@@ -39,8 +39,9 @@ coef_dist_dict:
default: 'gaussian'
# gamma_user: 'gamma'
# nfact_category: 'gamma'
gamma_user: 'gaussian'
nfact_category: 'lognormal'
# gamma_user: 'gamma'
gamma_user: 'lognormal'
# nfact_category: 'gamma'
prior_mean:
default: 0.0
# mean is shape for gamma variable
@@ -53,8 +54,8 @@ prior_variance:
default: 100000.0
# variance is rate for gamma variable
# shape is mean / var
gamma_user: 10.0
nfact_category: 10.0
gamma_user: 2.0
nfact_category: 2.0
#### optimization.
trace_log_q: False
shuffle: False
@@ -63,7 +64,7 @@ num_epochs: 200
learning_rate: 0.03
num_mc_seeds: 1
num_price_obs: 1
obs_user: True
obs_user: False
obs_item: False
patience: 10
complete_availability: True
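A hypothetical sketch of how the edited settings might be read back, assuming only the key layout visible in this diff (the BEMB-side consumer is not shown here):

    import yaml

    with open("tutorials/supermarket/configs3_1.yaml") as f:
        cfg = yaml.safe_load(f)

    dists = cfg["coef_dist_dict"]
    dist = dists.get("gamma_user", dists["default"])  # 'lognormal' after this commit
    prior_var = cfg["prior_variance"].get("gamma_user", cfg["prior_variance"]["default"])  # 2.0
    print(dist, prior_var)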
