diff --git a/bemb/model/bayesian_coefficient.py b/bemb/model/bayesian_coefficient.py
index a12f490..0de89cb 100644
--- a/bemb/model/bayesian_coefficient.py
+++ b/bemb/model/bayesian_coefficient.py
@@ -86,8 +86,12 @@ def __init__(self,
             assert variance > 0, 'Gamma distribution requires variance > 0'
             # shape (concentration) is mean^2/variance, rate is variance/mean for Gamma distribution.
             '''
+            self.mean_clamp = (np.log(0.1), np.log(100.0))
+            self.logstd_clamp = (np.log(0.1), np.log(10000.0))
             shape = prior_mean ** 2 / prior_variance
+            assert shape > np.exp(self.mean_clamp[0])**2 and shape < np.exp(self.mean_clamp[1])**2, f'Gamma shape {shape} is out of range, should be in ({np.exp(self.mean_clamp[0])**2}, {np.exp(self.mean_clamp[1])**2})'
             rate = prior_mean / prior_variance
+            assert rate > np.exp(self.logstd_clamp[0]) and rate < np.exp(self.logstd_clamp[1]), f'Gamma rate {rate} is out of range, should be in ({np.exp(self.logstd_clamp[0])}, {np.exp(self.logstd_clamp[1])})'
             # prior_mean stores ln(shape) for gamma
             prior_mean = np.log(shape)
             # prior_variance stores rate for gamma
@@ -95,6 +99,20 @@ def __init__(self,
             # prior_mean = np.log(prior_mean)
             # prior_variance = prior_variance

+        elif distribution == 'lognormal':
+            # mean is exp(mu + sigma^2/2), variance is (exp(sigma^2) - 1) * exp(2*mu + sigma^2)
+            # prior_mean in (-2, exp(1.5))
+            self.mean_clamp = (-2.0, np.exp(1.5))
+            # sigma^2 clamp: (exp(-20), exp(1.5))
+            # therefore sigma in (exp(-10), exp(0.75))
+            # therefore log sigma in (-10, 0.75)
+            self.logstd_clamp = (-10.0, 0.75)
+            assert prior_mean > self.mean_clamp[0] and prior_mean < self.mean_clamp[1], f'Lognormal distribution requires prior_mean in {self.mean_clamp}, given {prior_mean}'
+            assert np.sqrt(prior_variance) > np.exp(self.logstd_clamp[0]) and np.sqrt(prior_variance) < np.exp(self.logstd_clamp[1]), f'Lognormal distribution requires prior_variance in {self.logstd_clamp}, given {prior_variance}'
+            # assert prior_mean > np.exp(-100.0) and prior_mean < np.exp(10.0), f'Lognormal distribution requires shape in (exp(-100), exp(10)), given {prior_mean}'
+            # assert prior_variance > np.exp(-100.0) and prior_variance < np.exp(2.0), f'Lognormal distribution requires rate in (exp(-100), exp(2)), given {prior_variance}'
+
+
         self.distribution = distribution

         self.obs2prior = obs2prior
@@ -144,23 +162,49 @@ def __init__(self,
                 num_classes, dim) * self.prior_variance)

         # create variational distribution.
-        if self.distribution == 'gaussian' or self.distribution == 'lognormal':
+        if self.distribution == 'gaussian':
+            if self.is_H:
+                self.variational_mean_flexible = nn.Parameter(
+                    torch.randn(num_classes, dim), requires_grad=True)
+                # multiply by 0.0001 to avoid numerical issues.
+                self.variational_mean_flexible.data *= 0.0001
+
+            else:
+                self.variational_mean_flexible = nn.Parameter(
+                    torch.randn(num_classes, dim), requires_grad=True)
+        elif self.distribution == 'lognormal':
             self.variational_mean_flexible = nn.Parameter(
                 torch.randn(num_classes, dim), requires_grad=True)
+            self.variational_mean_flexible.data = torch.clamp(
+                self.variational_mean_flexible.data, min=self.mean_clamp[0], max=self.mean_clamp[1])

         # TODO(kanodiaayush): initialize the gamma distribution variational mean in a more principled way.
         elif self.distribution == 'gamma':
             # initialize using uniform distribution between 0.5 and 1.5
             # for a gamma distribution, we store the concentration (shape) as log(concentration) = variational_mean_flexible
             self.variational_mean_flexible = nn.Parameter(
                 torch.rand(num_classes, dim) + 0.5, requires_grad=True)
+            self.variational_mean_flexible.data = torch.clamp(
+                self.variational_mean_flexible.data, min=self.mean_clamp[0], max=self.mean_clamp[1])

         if self.is_H and self.H_zero_mask is not None:
             assert self.H_zero_mask.shape == self.variational_mean_flexible.shape, \
                 f"The H_zero_mask should have exactly the shape as the H variable, `H_zero_mask`.shape is {self.H_zero_mask.shape}, `H`.shape is {self.variational_mean_flexible.shape} "

         # for gamma distribution, we store the rate as log(rate) = variational_logstd
-        self.variational_logstd = nn.Parameter(
-            torch.randn(num_classes, dim), requires_grad=True)
+        if self.distribution == 'gaussian':
+            self.variational_logstd = nn.Parameter(
+                torch.randn(num_classes, dim), requires_grad=True)
+        elif self.distribution == 'lognormal':
+            # uniform -1 to 1
+            self.variational_logstd = nn.Parameter(
+                torch.rand(num_classes, dim) * 2 - 1, requires_grad=True)
+            self.variational_logstd.data = torch.clamp(
+                self.variational_logstd.data, min=self.logstd_clamp[0], max=self.logstd_clamp[1])
+        elif self.distribution == 'gamma':
+            self.variational_logstd = nn.Parameter(
+                torch.randn(num_classes, dim), requires_grad=True)
+            self.variational_logstd.data = torch.clamp(
+                self.variational_logstd.data, min=self.logstd_clamp[0], max=self.logstd_clamp[1])

         self.register_buffer('variational_cov_factor',
                              torch.zeros(num_classes, dim, 1))
@@ -269,10 +313,13 @@ def log_prior(self,
                                             cov_factor=self.prior_cov_factor,
                                             cov_diag=self.prior_cov_diag).log_prob(sample)
         elif self.distribution == 'lognormal':
-            # out = LogNormal(loc=mu, scale=np.sqrt(self.prior_variance)).log_prob(sample)
-            # out = torch.sum(out, dim=-1)
-            out = torch.zeros((num_seeds, num_classes), device=sample.device)
+            mu = torch.clamp(mu, min=-100.0, max=10.0)
+            out = LogNormal(loc=mu, scale=np.sqrt(self.prior_variance)).log_prob(sample)
+            out = torch.sum(out, dim=-1)
+            # out = torch.zeros((num_seeds, num_classes), device=sample.device)
+
         elif self.distribution == 'gamma':
+            mu = torch.clamp(mu, min=-100.0, max=4.0)
             concentration = torch.exp(mu)
             rate = self.prior_variance
             '''
@@ -286,7 +333,7 @@ def log_prior(self,
             # sum over the last dimension
             out = torch.sum(out, dim=-1)

-            out = torch.zeros((num_seeds, num_classes), device=sample.device)
+            # out = torch.zeros((num_seeds, num_classes), device=sample.device)
         assert out.shape == (num_seeds, num_classes)
         return out
@@ -323,9 +370,10 @@ def rsample(self, num_seeds: int = 1) -> Union[torch.Tensor, Tuple[torch.Tensor]
         """
         value_sample = self.variational_distribution.rsample(
             torch.Size([num_seeds]))
-        if self.distribution == 'lognormal':
-            print(torch.min(value_sample))
-            print(torch.max(value_sample))
+        # if self.distribution == 'lognormal':
+        #     print(torch.min(value_sample))
+        #     print(torch.max(value_sample))
+        # breakpoint()  # DEBUG_MARKER
         if self.obs2prior:
             # sample obs2prior H as well.
@@ -347,20 +395,31 @@ def variational_distribution(self) -> Union[LowRankMultivariateNormal, Gamma]:
         elif self.distribution == 'lognormal':
             # print(self.variational_mean_flexible)
             # print(self.variational_logstd)
-            print(torch.max(self.variational_logstd), torch.min(self.variational_logstd))
-            print(torch.max(self.variational_mean_flexible), torch.min(self.variational_mean_flexible))
-            print(self.variational_mean_flexible.shape, self.variational_logstd.shape)
+            # print(torch.max(self.variational_logstd), torch.min(self.variational_logstd))
+            # print(torch.max(self.variational_mean_flexible), torch.min(self.variational_mean_flexible))
+            # print(self.variational_mean_flexible.shape, self.variational_logstd.shape)
             # return LowRankMultivariateNormal(loc=self.variational_mean_flexible,
             #                                  cov_factor=self.variational_cov_factor,
             #                                  cov_diag=torch.exp(self.variational_logstd))
-            return LogNormal(loc=self.variational_mean_flexible, scale=torch.exp(self.variational_logstd))
+            # variational_mean_flexible = torch.clamp(self.variational_mean_flexible, min=-10, max=10)
+            # variational_logstd = torch.clamp(self.variational_logstd, min=-4, max=3)
+            loc = self.variational_mean_flexible
+            scale = torch.exp(self.variational_logstd)
+            return LogNormal(loc=loc, scale=scale)
         elif self.distribution == 'gamma':
             # for a gamma distribution, we store the concentration as log(concentration) = variational_mean_flexible
-            assert self.variational_mean_fixed == None, 'Gamma distribution does not support fixed mean'
-            concentration = self.variational_mean_flexible.exp() + 0.000001
-            concentration = torch.minimum(concentration, torch.tensor(1e3))
+            # assert self.variational_mean_fixed == None, 'Gamma distribution does not support fixed mean'
+            concentration = torch.exp(self.variational_mean_flexible)
+            # assert that all concentration should be between exp -4 and exp 4
+            # assert torch.all(concentration > 0.1353 - 0.0001), 'concentration should be greater than exp -2'
+            # assert torch.all(concentration < 54.5981 + 0.0001), 'concentration should be less than exp 4'
+            # concentration = self.variational_mean_flexible.exp() + 0.000001
+            # concentration = torch.clamp(concentration, min=1e-2, max=1e2)
+            # concentration = torch.minimum(concentration, torch.tensor(1e3))
             # for gamma distribution, we store the rate as log(rate) = variational_logstd
             rate = torch.exp(self.variational_logstd)
+            # print(concentration, rate)
+            # rate = torch.clamp(rate, min=1e-2, max=1e2)
             return Gamma(concentration=concentration, rate=rate)
         else:
             raise NotImplementedError("Unknown variational distribution type.")
@@ -369,3 +428,26 @@ def variational_distribution(self) -> Union[LowRankMultivariateNormal, Gamma]:
     def device(self) -> torch.device:
         """Returns the device of tensors contained in this module."""
         return self.variational_mean.device
+
+    def clamp_params(self) -> None:
+        """Clamps the parameters of the variational distribution to be within a reasonable range.
+        """
+        if self.distribution == 'gaussian':
+            # do nothing
+            pass
+            # self.variational_mean_flexible.data = torch.clamp(
+            #     self.variational_mean_flexible.data, min=-10, max=10)
+            # self.variational_logstd.data = torch.clamp(
+            #     self.variational_logstd.data, min=-4, max=3)
+        elif self.distribution in ['lognormal', 'gamma']:
+            self.variational_mean_flexible.data = torch.clamp(
+                self.variational_mean_flexible.data, min=self.mean_clamp[0], max=self.mean_clamp[1])
+            self.variational_logstd.data = torch.clamp(
+                self.variational_logstd.data, min=self.logstd_clamp[0], max=self.logstd_clamp[1])
+        # elif self.distribution == 'gamma':
+        #     self.variational_mean_flexible.data = torch.clamp(
+        #         self.variational_mean_flexible.data, min=-2.0, max=4)
+        #     self.variational_logstd.data = torch.clamp(
+        #         self.variational_logstd.data, min=-100.0, max=2.0)
+        else:
+            raise NotImplementedError("Unknown variational distribution type.")
diff --git a/bemb/model/bemb.py b/bemb/model/bemb.py
index d875d32..f6df9d9 100644
--- a/bemb/model/bemb.py
+++ b/bemb/model/bemb.py
@@ -422,6 +422,10 @@ def __init__(self,
                 'Additional modules are temporarily disabled for further development.')
             self.additional_modules = nn.ModuleList(additional_modules)

+    def clamp_coefs(self):
+        for coef_name in self.coef_dict.keys():
+            self.coef_dict[coef_name].clamp_params()
+
     def __str__(self):
         return f'Bayesian EMBedding Model with U[user, item, session] = {self.raw_formula}\n' \
             + f'Total number of parameters: {self.num_params}.\n' \
@@ -697,13 +701,14 @@ def sample_coefficient_dictionary(self, num_seeds: int, deterministic: bool = Fa
         for coef_name, coef in self.coef_dict.items():
             if deterministic:
                 s = coef.variational_distribution.mean.unsqueeze(dim=0)  # (1, num_*, dim)
+                # print(torch.min(s), torch.max(s))
+                # breakpoint()
                 # if coef.distribution == 'lognormal':
                 #     s = torch.exp(s)
                 sample_dict[coef_name] = s
                 if coef.obs2prior:
                     sample_dict[coef_name + '.H'] = coef.prior_H.variational_distribution.mean.unsqueeze(dim=0)  # (1, num_*, dim)
             else:
-                print(coef_name)
                 s = coef.rsample(num_seeds)
                 if coef.obs2prior:
                     # sample both obs2prior weight and realization of variable.
diff --git a/bemb/model/bemb_chunked.py b/bemb/model/bemb_chunked.py
index 5a8c5dc..b1572c8 100644
--- a/bemb/model/bemb_chunked.py
+++ b/bemb/model/bemb_chunked.py
@@ -358,7 +358,7 @@ def __init__(self,
             bayesian_coefs_inner = []
             for jj in range(chunk_sizes[1]):
                 if self.coef_dist_dict[coef_name] == 'gamma' and not self.obs2prior_dict[coef_name]:
-                    assert mean > 0, 'shape of gamma distribution specifieid as prior_mean needs to be > 0'
+                    assert mean > 0, 'shape of gamma distribution specified as prior_mean needs to be > 0'
                 bayesian_coefs_inner.append(BayesianCoefficient(variation=variation,
                                                                 num_classes=variation_to_num_classes[variation],
                                                                 obs2prior=self.obs2prior_dict[coef_name],
@@ -386,6 +386,13 @@ def __init__(self,
                 'Additional modules are temporarily disabled for further development.')
             self.additional_modules = nn.ModuleList(additional_modules)

+
+    def clamp_coefs(self):
+        for coef_name in self.coef_dict.keys():
+            for ii in range(len(self.coef_dict[coef_name])):
+                for jj in range(len(self.coef_dict[coef_name][ii])):
+                    self.coef_dict[coef_name][ii][jj].clamp_params()
+
     def __str__(self):
         return f'Bayesian EMBedding Model with U[user, item, session] = {self.raw_formula}\n' \
             + f'Total number of parameters: {self.num_params}.\n' \
@@ -1081,6 +1088,7 @@ def reshape_observable(obs, name):
                 if return_price_coeff and term['observable'] is not None and term['observable'].startswith('price_'):
                     obs_coeff = coef_sample.sum(dim=-1)
                     price_coeffs = obs_coeff
+                    price_coeffs *= term['sign']

                 additive_term = (coef_sample * obs).sum(dim=-1)
                 additive_term *= term['sign']
@@ -1115,6 +1123,7 @@ def reshape_observable(obs, name):
                 if return_price_coeff and term['observable'] is not None and term['observable'].startswith('price_'):
                     obs_coeff = coef.sum(dim=-1)
                     price_coeffs = obs_coeff
+                    price_coeffs *= term['sign']

                 additive_term = (coef * obs).sum(dim=-1)
                 additive_term *= term['sign']
diff --git a/bemb/model/bemb_supermarket_lightning.py b/bemb/model/bemb_supermarket_lightning.py
index a7b40cb..a26dc4e 100644
--- a/bemb/model/bemb_supermarket_lightning.py
+++ b/bemb/model/bemb_supermarket_lightning.py
@@ -150,6 +150,12 @@ def test_dataloader(self):
                                      num_workers=self.num_workers)
        return test_dataloader

+    # def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure, ):
+    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
+        optimizer.step(closure=optimizer_closure)
+        with torch.no_grad():
+            self.model.clamp_coefs()
+
     def write_bemb_cpp_format(self):
         model = self.model
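Note (not part of the patch): the clamp ranges and asserts added in bayesian_coefficient.py follow from the standard moment identities for the Gamma and LogNormal distributions (shape = mean^2 / variance and rate = mean / variance for Gamma; mean = exp(mu + sigma^2/2) and variance = (exp(sigma^2) - 1) * exp(2*mu + sigma^2) for LogNormal). The following is a minimal standalone sketch of those conversions; the helper names are illustrative and do not exist in the bemb package.

import numpy as np

# Gamma: mean = shape / rate, variance = shape / rate**2
#   => shape = mean**2 / variance, rate = mean / variance
def gamma_shape_rate(mean: float, variance: float):
    assert mean > 0 and variance > 0
    return mean ** 2 / variance, mean / variance

# LogNormal: mean = exp(mu + sigma**2 / 2),
#            variance = (exp(sigma**2) - 1) * exp(2 * mu + sigma**2)
def lognormal_moments(mu: float, sigma: float):
    mean = np.exp(mu + sigma ** 2 / 2)
    var = (np.exp(sigma ** 2) - 1) * np.exp(2 * mu + sigma ** 2)
    return mean, var

if __name__ == '__main__':
    # The patch stores log(shape) and log(rate) for the Gamma prior and asserts
    # that both land inside its configured clamp ranges before taking logs.
    shape, rate = gamma_shape_rate(mean=1.0, variance=2.0)  # shape = 0.5, rate = 0.5
    print(shape, rate)
    print(lognormal_moments(mu=0.0, sigma=0.5))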
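Similarly, as a reading aid for the variational_distribution property above: both the lognormal and gamma branches store parameters in log space, so the distribution is rebuilt by exponentiating where needed. The sketch below reproduces that construction on dummy tensors with torch.distributions; the variable names mirror the patch, but the snippet is illustrative rather than an excerpt.

import torch
from torch.distributions import Gamma, LogNormal

num_classes, dim = 5, 2
# Stores mu (lognormal) or log concentration (gamma).
variational_mean_flexible = torch.zeros(num_classes, dim)
# Stores log sigma (lognormal) or log rate (gamma).
variational_logstd = torch.zeros(num_classes, dim)

# 'lognormal': loc is used directly, scale is exp(log sigma).
q_lognormal = LogNormal(loc=variational_mean_flexible, scale=torch.exp(variational_logstd))

# 'gamma': both stored parameters are exponentiated.
q_gamma = Gamma(concentration=torch.exp(variational_mean_flexible), rate=torch.exp(variational_logstd))

sample = q_lognormal.rsample(torch.Size([3]))  # shape (num_seeds=3, num_classes, dim)
print(sample.shape, q_gamma.log_prob(sample).shape)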
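Finally, the new clamp_coefs()/clamp_params() hook is driven from the LightningModule's optimizer_step override at the end of the diff (its eight-argument signature matches the older, pre-2.0 PyTorch Lightning hook). The same pattern in a plain PyTorch loop looks like the sketch below: step the optimizer first, then clamp the raw parameter tensors outside autograd. The ToyCoefficient class is a hypothetical stand-in, not the real BayesianCoefficient.

import torch
import torch.nn as nn

class ToyCoefficient(nn.Module):
    """Hypothetical stand-in for BayesianCoefficient: one clamped log-std parameter."""
    def __init__(self, num_classes: int, dim: int):
        super().__init__()
        self.variational_logstd = nn.Parameter(torch.randn(num_classes, dim))
        self.logstd_clamp = (-10.0, 0.75)  # same bounds the patch uses for 'lognormal'

    def clamp_params(self) -> None:
        # In-place clamp of the raw parameter tensor, mirroring clamp_params in the patch.
        self.variational_logstd.data = torch.clamp(
            self.variational_logstd.data, min=self.logstd_clamp[0], max=self.logstd_clamp[1])

# Plain-PyTorch equivalent of the Lightning optimizer_step override:
# take the gradient step first, then clamp outside of autograd.
coef = ToyCoefficient(num_classes=4, dim=3)
optimizer = torch.optim.Adam(coef.parameters(), lr=1e-1)

for _ in range(10):
    loss = coef.variational_logstd.exp().sum()  # dummy objective
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    with torch.no_grad():
        coef.clamp_params()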