From 0e034dea91e91c1d2cfe7be1ddb864b4e8089cf1 Mon Sep 17 00:00:00 2001
From: Ayush Kanodia
Date: Mon, 17 Jun 2024 13:23:19 -0700
Subject: [PATCH] Update prior handling with positive variationals

1. The hyperprior is now Gaussian.
2. The prior is now Gaussian: for lognormal and gamma coefficients,
   log_prior evaluates a Gaussian density on the underlying
   (pre-transform) parameter instead of a LogNormal/Gamma density on
   the sample.
3. The lower bound of the lognormal mean clamp is lowered from -2.0
   to -10.0.
4. Handling is made clearer: the prior-range assertions and the gamma
   moment-matching run only when obs2prior is disabled, and
   clamp_params now uses the correct lower bound for the variational
   log-std.

---
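Notes: the gamma branch matches moments as shape = mean^2 / variance
and rate = mean / variance. A minimal standalone sketch of that
arithmetic (illustrative only, not part of the patch; the helper name
and the example values are made up):

    import numpy as np

    def gamma_params_from_moments(mean: float, variance: float):
        # For X ~ Gamma(shape, rate): E[X] = shape / rate and
        # Var[X] = shape / rate**2; inverting gives the two lines below.
        shape = mean ** 2 / variance
        rate = mean / variance
        return shape, rate

    shape, rate = gamma_params_from_moments(2.0, 2.0)
    assert np.isclose(shape / rate, 2.0)       # mean recovered
    assert np.isclose(shape / rate ** 2, 2.0)  # variance recovered
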
 bemb/model/bayesian_coefficient.py    | 54 ++++++++++++++-------------
 tutorials/supermarket/configs3_1.yaml | 19 +++++-----
 2 files changed, 38 insertions(+), 35 deletions(-)

diff --git a/bemb/model/bayesian_coefficient.py b/bemb/model/bayesian_coefficient.py
index 0de89cb..0a6460f 100644
--- a/bemb/model/bayesian_coefficient.py
+++ b/bemb/model/bayesian_coefficient.py
@@ -89,26 +89,28 @@ def __init__(self,
             self.mean_clamp = (np.log(0.1), np.log(100.0))
             self.logstd_clamp = (np.log(0.1), np.log(10000.0))
             shape = prior_mean ** 2 / prior_variance
-            assert shape > np.exp(self.mean_clamp[0])**2 and shape < np.exp(self.mean_clamp[1])**2, f'Gamma shape {shape} is out of range, should be in ({np.exp(self.mean_clamp[0])**2}, {np.exp(self.mean_clamp[1])**2})'
             rate = prior_mean / prior_variance
-            assert rate > np.exp(self.logstd_clamp[0]) and rate < np.exp(self.logstd_clamp[1]), f'Gamma rate {rate} is out of range, should be in ({np.exp(self.logstd_clamp[0])}, {np.exp(self.logstd_clamp[1])})'
-            # prior_mean stores ln(shape) for gamma
-            prior_mean = np.log(shape)
-            # prior_variance stores rate for gamma
-            prior_variance = rate
-            # prior_mean = np.log(prior_mean)
-            # prior_variance = prior_variance
+            if not obs2prior:
+                assert shape > np.exp(self.mean_clamp[0])**2 and shape < np.exp(self.mean_clamp[1])**2, f'Gamma shape {shape} is out of range, should be in ({np.exp(self.mean_clamp[0])**2}, {np.exp(self.mean_clamp[1])**2})'
+                assert rate > np.exp(self.logstd_clamp[0]) and rate < np.exp(self.logstd_clamp[1]), f'Gamma rate {rate} is out of range, should be in ({np.exp(self.logstd_clamp[0])}, {np.exp(self.logstd_clamp[1])})'
+                # prior_mean stores ln(shape) for gamma
+                prior_mean = np.log(shape)
+                # prior_variance stores rate for gamma
+                prior_variance = rate
+                # prior_mean = np.log(prior_mean)
+                # prior_variance = prior_variance
         elif distribution == 'lognormal':
             # mean is exp(mu + sigma^2/2), variance is (exp(sigma^2) - 1) * exp(2*mu + sigma^2)
-            # prior_mean in -2, exp(3)
-            self.mean_clamp = (-2.0, np.exp(1.5))
+            # prior_mean in (-10, exp(1.5))
+            self.mean_clamp = (-10.0, np.exp(1.5))
             # sigma sq clamp exp(-20), exp(1.5)
             # therefore sigma in (exp(-10), exp(0.75))
             # therefore log sigma in (-10, 0.75)
             self.logstd_clamp = (-10.0, 0.75)
-            assert prior_mean > self.mean_clamp[0] and prior_mean < self.mean_clamp[1], f'Lognormal distribution requires prior_mean in {self.mean_clamp}, given {prior_mean}'
-            assert np.sqrt(prior_variance) > np.exp(self.logstd_clamp[0]) and np.sqrt(prior_variance) < np.exp(self.logstd_clamp[1]), f'Lognormal distribution requires prior_variance in {self.logstd_clamp}, given {prior_variance}'
+            if not obs2prior:
+                assert prior_mean > self.mean_clamp[0] and prior_mean < self.mean_clamp[1], f'Lognormal distribution requires prior_mean in {self.mean_clamp}, given {prior_mean}'
+                assert np.sqrt(prior_variance) > np.exp(self.logstd_clamp[0]) and np.sqrt(prior_variance) < np.exp(self.logstd_clamp[1]), f'Lognormal distribution requires prior_variance in {self.logstd_clamp}, given {prior_variance}'
             # assert prior_mean > np.exp(-100.0) and prior_mean < np.exp(10.0), f'Lognormal distribution requires shape in (exp(-100), exp(10)), given {prior_mean}'
             # assert prior_variance > np.exp(-100.0) and prior_variance < np.exp(2.0), f'Lognormal distribution requires rate in (exp(-100), exp(2)), given {prior_variance}'
@@ -254,6 +256,7 @@ def variational_mean(self) -> torch.Tensor:
 
         elif self.distribution == 'lognormal':
             M = (torch.minimum((M.exp() + 0.000001), torch.tensor(1e3)))
+            # M = torch.minimum(M + 0.000001, torch.tensor(1e3))
 
         if self.is_H and (self.H_zero_mask is not None):
             # a H-variable with zero-entry restriction.
@@ -314,24 +317,23 @@ def log_prior(self,
                                             cov_diag=self.prior_cov_diag).log_prob(sample)
         elif self.distribution == 'lognormal':
             mu = torch.clamp(mu, min=-100.0, max=10.0)
-            out = LogNormal(loc=mu, scale=np.sqrt(self.prior_variance)).log_prob(sample)
-            out = torch.sum(out, dim=-1)
+            out = LowRankMultivariateNormal(loc=mu,
+                                            cov_factor=self.prior_cov_factor,
+                                            cov_diag=self.prior_cov_diag).log_prob(sample)
+            # out = LogNormal(loc=mu, scale=np.sqrt(self.prior_variance)).log_prob(sample)
+            # out = torch.sum(out, dim=-1)
             # out = torch.zeros((num_seeds, num_classes), device=sample.device)
         elif self.distribution == 'gamma':
             mu = torch.clamp(mu, min=-100.0, max=4.0)
-            concentration = torch.exp(mu)
-            rate = self.prior_variance
-            '''
-            print(concentration)
-            print('concentration')
-            print(rate)
-            print('rate')
-            '''
-            out = Gamma(concentration=concentration,
-                        rate=rate).log_prob(sample)
-            # sum over the last dimension
-            out = torch.sum(out, dim=-1)
+            out = LowRankMultivariateNormal(loc=mu,
+                                            cov_factor=self.prior_cov_factor,
+                                            cov_diag=self.prior_cov_diag).log_prob(sample)
+            # concentration = torch.exp(mu)
+            # rate = self.prior_variance
+            # out = Gamma(concentration=concentration,
+            #             rate=rate).log_prob(sample)
+            # out = torch.sum(out, dim=-1)
             # out = torch.zeros((num_seeds, num_classes), device=sample.device)
 
         assert out.shape == (num_seeds, num_classes)
@@ -443,7 +445,7 @@ def clamp_params(self) -> None:
             self.variational_mean_flexible.data = torch.clamp(
                 self.variational_mean_flexible.data, min=self.mean_clamp[0], max=self.mean_clamp[1])
             self.variational_logstd.data = torch.clamp(
-                self.variational_logstd.data, min=self.mean_clamp[0], max=self.logstd_clamp[1])
+                self.variational_logstd.data, min=self.logstd_clamp[0], max=self.logstd_clamp[1])
         # elif self.distribution == 'gamma':
         #     self.variational_mean_flexible.data = torch.clamp(
         #         self.variational_mean_flexible.data, min=-2.0, max=4)
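For the log_prior change above: with lognormal and gamma coefficients,
the prior log-density is now a Gaussian evaluated at the pre-transform
parameter via LowRankMultivariateNormal with the prior covariance
factors. A minimal sketch of that call (the tensor sizes here are
illustrative assumptions, not taken from the model):

    import torch
    from torch.distributions import LowRankMultivariateNormal

    num_seeds, num_classes, dim = 2, 5, 3
    mu = torch.zeros(num_classes, dim)              # clamped prior location
    cov_factor = torch.zeros(num_classes, dim, 1)   # zero factor -> diagonal covariance
    cov_diag = torch.full((num_classes, dim), 2.0)  # prior variance on the diagonal
    sample = torch.randn(num_seeds, num_classes, dim)

    out = LowRankMultivariateNormal(loc=mu,
                                    cov_factor=cov_factor,
                                    cov_diag=cov_diag).log_prob(sample)
    assert out.shape == (num_seeds, num_classes)    # matches the assert in log_prior
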
diff --git a/tutorials/supermarket/configs3_1.yaml b/tutorials/supermarket/configs3_1.yaml
index 5a3971d..c2bbb7e 100644
--- a/tutorials/supermarket/configs3_1.yaml
+++ b/tutorials/supermarket/configs3_1.yaml
@@ -3,19 +3,19 @@ device: cuda
 # data_dir: /home/tianyudu/Data/MoreSupermarket/20180101-20191231_13/tsv/
 # data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/m_9999/20180101-20191231_42/tsv
 # data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/m_9999/2
-data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/m_1904-1887-1974-2012-1992/20180101-20191231_44/tsv/
+# data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/m_1904-1887-1974-2012-1992/20180101-20191231_44/tsv/
 # data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/m_1887/chunked_1
-# data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/sims34/1
+data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/sims34/1
 # utility: lambda_item
 # utility: lambda_item + theta_user * alpha_item
 # utility: lambda_item + theta_user * alpha_item + zeta_user * item_obs
 # utility: lambda_item + theta_user * alpha_item + gamma_user * beta_item * price_obs
 # utility: lambda_item + theta_user * alpha_item + gamma_user * price_obs
 # utility: lambda_item + theta_user * alpha_item - nfact_category * gamma_user * price_obs
-utility: theta_user * alpha_item - nfact_category * gamma_user * price_obs
+# utility: theta_user * alpha_item + nfact_category * gamma_user * price_obs
 # utility: -nfact_category * gamma_user * price_obs
 # utility: -nfact_category * gamma_user * price_obs
-# utility: -gamma_user * price_obs
+utility: -gamma_user * price_obs
 # utility: lambda_item
 out_dir: ./output/
 # model configuration.
@@ -39,8 +39,9 @@ coef_dist_dict:
   default: 'gaussian'
   # gamma_user: 'gamma'
   # nfact_category: 'gamma'
-  gamma_user: 'gaussian'
-  nfact_category: 'lognormal'
+  # gamma_user: 'gamma'
+  gamma_user: 'lognormal'
+  # nfact_category: 'gamma'
 prior_mean:
   default: 0.0
   # mean is shape for gamma variable
@@ -53,8 +54,8 @@ prior_variance:
   default: 100000.0
   # variance is rate for gamma variable
-  # shape is mean / var
+  # shape is mean^2 / var, rate is mean / var
-  gamma_user: 10.0
-  nfact_category: 10.0
+  gamma_user: 2.0
+  nfact_category: 2.0
 #### optimization.
 trace_log_q: False
 shuffle: False
@@ -63,7 +64,7 @@ num_epochs: 200
 learning_rate: 0.03
 num_mc_seeds: 1
 num_price_obs: 1
-obs_user: True
+obs_user: False
 obs_item: False
 patience: 10
 complete_availability: True
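
The lognormal clamp comments in the first hunk can be checked
numerically: for X = exp(Z) with Z ~ N(mu, sigma^2), E[X] =
exp(mu + sigma^2 / 2) and Var[X] = (exp(sigma^2) - 1) * exp(2*mu + sigma^2).
A quick sanity check against torch, with values chosen inside the new
clamp ranges (illustrative only, not part of the patch):

    import torch
    from torch.distributions import LogNormal

    mu = torch.tensor(-1.0)    # inside the widened mean clamp (-10.0, e**1.5)
    sigma = torch.tensor(0.5)  # log(0.5) ~ -0.69, inside the logstd clamp (-10.0, 0.75)
    dist = LogNormal(loc=mu, scale=sigma)

    mean = torch.exp(mu + sigma ** 2 / 2)
    variance = (torch.exp(sigma ** 2) - 1) * torch.exp(2 * mu + sigma ** 2)
    assert torch.isclose(dist.mean, mean)
    assert torch.isclose(dist.variance, variance)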