diff --git a/bemb/model/bayesian_coefficient.py b/bemb/model/bayesian_coefficient.py
index 0de89cb..0a6460f 100644
--- a/bemb/model/bayesian_coefficient.py
+++ b/bemb/model/bayesian_coefficient.py
@@ -89,26 +89,28 @@ def __init__(self,
             self.mean_clamp = (np.log(0.1), np.log(100.0))
             self.logstd_clamp = (np.log(0.1), np.log(10000.0))
             shape = prior_mean ** 2 / prior_variance
-            assert shape > np.exp(self.mean_clamp[0])**2 and shape < np.exp(self.mean_clamp[1])**2, f'Gamma shape {shape} is out of range, should be in ({np.exp(self.mean_clamp[0])**2}, {np.exp(self.mean_clamp[1])**2})'
             rate = prior_mean / prior_variance
-            assert rate > np.exp(self.logstd_clamp[0]) and rate < np.exp(self.logstd_clamp[1]), f'Gamma rate {rate} is out of range, should be in ({np.exp(self.logstd_clamp[0])}, {np.exp(self.logstd_clamp[1])})'
-            # prior_mean stores ln(shape) for gamma
-            prior_mean = np.log(shape)
-            # prior_variance stores rate for gamma
-            prior_variance = rate
-            # prior_mean = np.log(prior_mean)
-            # prior_variance = prior_variance
+            if not obs2prior:
+                assert shape > np.exp(self.mean_clamp[0])**2 and shape < np.exp(self.mean_clamp[1])**2, f'Gamma shape {shape} is out of range, should be in ({np.exp(self.mean_clamp[0])**2}, {np.exp(self.mean_clamp[1])**2})'
+                assert rate > np.exp(self.logstd_clamp[0]) and rate < np.exp(self.logstd_clamp[1]), f'Gamma rate {rate} is out of range, should be in ({np.exp(self.logstd_clamp[0])}, {np.exp(self.logstd_clamp[1])})'
+                # prior_mean stores ln(shape) for gamma
+                prior_mean = np.log(shape)
+                # prior_variance stores rate for gamma
+                prior_variance = rate
+                # prior_mean = np.log(prior_mean)
+                # prior_variance = prior_variance
         elif distribution == 'lognormal':
             # mean is exp(mu + sigma^2/2), variance is (exp(sigma^2) - 1) * exp(2*mu + sigma^2)
             # prior_mean in -2, exp(3)
-            self.mean_clamp = (-2.0, np.exp(1.5))
+            self.mean_clamp = (-10.0, np.exp(1.5))
             # sigma sq clamp exp(-20), exp(1.5)
             # therefore sigma in (exp(-10), exp(0.75))
             # therefore log sigma in (-10, 0.75)
             self.logstd_clamp = (-10.0, 0.75)
-            assert prior_mean > self.mean_clamp[0] and prior_mean < self.mean_clamp[1], f'Lognormal distribution requires prior_mean in {self.mean_clamp}, given {prior_mean}'
-            assert np.sqrt(prior_variance) > np.exp(self.logstd_clamp[0]) and np.sqrt(prior_variance) < np.exp(self.logstd_clamp[1]), f'Lognormal distribution requires prior_variance in {self.logstd_clamp}, given {prior_variance}'
+            if not obs2prior:
+                assert prior_mean > self.mean_clamp[0] and prior_mean < self.mean_clamp[1], f'Lognormal distribution requires prior_mean in {self.mean_clamp}, given {prior_mean}'
+                assert np.sqrt(prior_variance) > np.exp(self.logstd_clamp[0]) and np.sqrt(prior_variance) < np.exp(self.logstd_clamp[1]), f'Lognormal distribution requires prior_variance in {self.logstd_clamp}, given {prior_variance}'
             # assert prior_mean > np.exp(-100.0) and prior_mean < np.exp(10.0), f'Lognormal distribution requires shape in (exp(-100), exp(10)), given {prior_mean}'
             # assert prior_variance > np.exp(-100.0) and prior_variance < np.exp(2.0), f'Lognormal distribution requires rate in (exp(-100), exp(2)), given {prior_variance}'
@@ -254,6 +256,7 @@ def variational_mean(self) -> torch.Tensor:
         elif self.distribution == 'lognormal':
             M = (torch.minimum((M.exp() + 0.000001), torch.tensor(1e3)))
+            # M = torch.minimum(M + 0.000001, torch.tensor(1e3))

         if self.is_H and (self.H_zero_mask is not None):
             # a H-variable with zero-entry restriction.
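Note: the range checks in the hunk above are moment matching. For a Gamma prior specified as (prior_mean, prior_variance), mean = shape / rate and variance = shape / rate^2, so shape = mean^2 / variance and rate = mean / variance, which is exactly what the new "if not obs2prior:" branch validates before storing ln(shape) and rate. A minimal standalone sketch of the conversion (illustrative values, not library code):

    import numpy as np

    def gamma_from_moments(mean: float, var: float):
        # Gamma: mean = shape / rate, variance = shape / rate ** 2
        # => shape = mean ** 2 / var, rate = mean / var
        shape = mean ** 2 / var
        rate = mean / var
        return shape, rate

    # e.g. a prior with mean 1.0 and variance 2.0 maps to Gamma(shape=0.5, rate=0.5)
    print(gamma_from_moments(1.0, 2.0))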
@@ -314,24 +317,23 @@ def log_prior(self,
                                             cov_diag=self.prior_cov_diag).log_prob(sample)
         elif self.distribution == 'lognormal':
             mu = torch.clamp(mu, min=-100.0, max=10.0)
-            out = LogNormal(loc=mu, scale=np.sqrt(self.prior_variance)).log_prob(sample)
-            out = torch.sum(out, dim=-1)
+            out = LowRankMultivariateNormal(loc=mu,
+                                            cov_factor=self.prior_cov_factor,
+                                            cov_diag=self.prior_cov_diag).log_prob(sample)
+            # out = LogNormal(loc=mu, scale=np.sqrt(self.prior_variance)).log_prob(sample)
+            # out = torch.sum(out, dim=-1)
             # out = torch.zeros((num_seeds, num_classes), device=sample.device)
         elif self.distribution == 'gamma':
             mu = torch.clamp(mu, min=-100.0, max=4.0)
-            concentration = torch.exp(mu)
-            rate = self.prior_variance
-            '''
-            print(concentration)
-            print('concentration')
-            print(rate)
-            print('rate')
-            '''
-            out = Gamma(concentration=concentration,
-                        rate=rate).log_prob(sample)
-            # sum over the last dimension
-            out = torch.sum(out, dim=-1)
+            out = LowRankMultivariateNormal(loc=mu,
+                                            cov_factor=self.prior_cov_factor,
+                                            cov_diag=self.prior_cov_diag).log_prob(sample)
+            # concentration = torch.exp(mu)
+            # rate = self.prior_variance
+            # out = Gamma(concentration=concentration,
+            #             rate=rate).log_prob(sample)
+            # out = torch.sum(out, dim=-1)
             # out = torch.zeros((num_seeds, num_classes), device=sample.device)

         assert out.shape == (num_seeds, num_classes)
@@ -443,7 +445,7 @@ def clamp_params(self) -> None:
             self.variational_mean_flexible.data = torch.clamp(
                 self.variational_mean_flexible.data, min=self.mean_clamp[0], max=self.mean_clamp[1])
             self.variational_logstd.data = torch.clamp(
-                self.variational_logstd.data, min=self.mean_clamp[0], max=self.logstd_clamp[1])
+                self.variational_logstd.data, min=self.logstd_clamp[0], max=self.logstd_clamp[1])
         # elif self.distribution == 'gamma':
         #     self.variational_mean_flexible.data = torch.clamp(
         #         self.variational_mean_flexible.data, min=-2.0, max=4)
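Note: two behavior changes above. First, log_prior now scores lognormal and gamma coefficients with a Gaussian density on the clamped pre-transform parameter mu, reusing the LowRankMultivariateNormal call from the gaussian branch rather than the exact LogNormal/Gamma densities (kept as comments). Second, clamp_params previously clamped variational_logstd with min=self.mean_clamp[0], a copy-paste bug; the fix uses logstd_clamp[0]. A self-contained sketch of the prior call (num_classes, dim, rank are illustrative shapes, not the library's values):

    import torch
    from torch.distributions import LowRankMultivariateNormal

    num_classes, dim, rank = 5, 3, 1                   # illustrative shapes
    mu = torch.zeros(num_classes, dim)                 # prior location (pre-transform space)
    cov_factor = torch.zeros(num_classes, dim, rank)   # zero low-rank part => diagonal covariance
    cov_diag = torch.ones(num_classes, dim)            # per-dimension prior variances

    sample = torch.randn(num_classes, dim)
    out = LowRankMultivariateNormal(loc=mu, cov_factor=cov_factor,
                                    cov_diag=cov_diag).log_prob(sample)
    print(out.shape)  # torch.Size([5]): one log-density per class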
diff --git a/tutorials/supermarket/configs3_1.yaml b/tutorials/supermarket/configs3_1.yaml
index 5a3971d..c2bbb7e 100644
--- a/tutorials/supermarket/configs3_1.yaml
+++ b/tutorials/supermarket/configs3_1.yaml
@@ -3,19 +3,19 @@ device: cuda
 # data_dir: /home/tianyudu/Data/MoreSupermarket/20180101-20191231_13/tsv/
 # data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/m_9999/20180101-20191231_42/tsv
 # data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/m_9999/2
-data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/m_1904-1887-1974-2012-1992/20180101-20191231_44/tsv/
+# data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/m_1904-1887-1974-2012-1992/20180101-20191231_44/tsv/
 # data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/m_1887/chunked_1
-# data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/sims34/1
+data_dir: /oak/stanford/groups/athey/MoreSupermarkets/csv/new_data/nf_runs/sims34/1
 # utility: lambda_item
 # utility: lambda_item + theta_user * alpha_item
 # utility: lambda_item + theta_user * alpha_item + zeta_user * item_obs
 # utility: lambda_item + theta_user * alpha_item + gamma_user * beta_item * price_obs
 # utility: lambda_item + theta_user * alpha_item + gamma_user * price_obs
 # utility: lambda_item + theta_user * alpha_item - nfact_category * gamma_user * price_obs
-utility: theta_user * alpha_item - nfact_category * gamma_user * price_obs
+# utility: theta_user * alpha_item + nfact_category * gamma_user * price_obs
 # utility: -nfact_category * gamma_user * price_obs
 # utility: -nfact_category * gamma_user * price_obs
-# utility: -gamma_user * price_obs
+utility: -gamma_user * price_obs
 # utility: lambda_item
 out_dir: ./output/
 # model configuration.
@@ -39,8 +39,9 @@ coef_dist_dict:
   default: 'gaussian'
   # gamma_user: 'gamma'
   # nfact_category: 'gamma'
-  gamma_user: 'gaussian'
-  nfact_category: 'lognormal'
+  # gamma_user: 'gamma'
+  gamma_user: 'lognormal'
+  # nfact_category: 'gamma'
 prior_mean:
   default: 0.0
   # mean is shape for gamma variable
@@ -53,8 +54,8 @@ prior_variance:
   default: 100000.0
   # variance is rate for gamma variable
   # shape is mean / var
-  gamma_user: 10.0
-  nfact_category: 10.0
+  gamma_user: 2.0
+  nfact_category: 2.0
 #### optimization.
 trace_log_q: False
 shuffle: False
@@ -63,7 +64,7 @@ num_epochs: 200
 learning_rate: 0.03
 num_mc_seeds: 1
 num_price_obs: 1
-obs_user: True
+obs_user: False
 obs_item: False
 patience: 10
 complete_availability: True
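Note: with gamma_user switched to 'lognormal' and obs_user set to False, the configured prior moments must satisfy the new obs2prior-disabled assertions in bayesian_coefficient.py. A pre-flight check that mirrors those bounds (check_lognormal_prior is a hypothetical helper, not part of bemb):

    import numpy as np

    def check_lognormal_prior(prior_mean: float, prior_variance: float) -> None:
        mean_clamp = (-10.0, np.exp(1.5))   # bounds set in __init__ above
        logstd_clamp = (-10.0, 0.75)
        assert mean_clamp[0] < prior_mean < mean_clamp[1], \
            f'prior_mean {prior_mean} outside {mean_clamp}'
        assert np.exp(logstd_clamp[0]) < np.sqrt(prior_variance) < np.exp(logstd_clamp[1]), \
            f'sqrt(prior_variance) {np.sqrt(prior_variance):.3f} outside the sigma clamp range'

    # Values from this config: gamma_user falls back to the 0.0 default prior_mean and
    # uses the new prior_variance of 2.0; sqrt(2) ~ 1.41 < exp(0.75) ~ 2.12, so it passes.
    check_lognormal_prior(prior_mean=0.0, prior_variance=2.0)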