Commit

gamma and lognormal coeffs

kanodiaayush committed Jan 30, 2024
1 parent 8236f92 commit b07d90a
Showing 4 changed files with 121 additions and 19 deletions.
116 changes: 99 additions & 17 deletions bemb/model/bayesian_coefficient.py
@@ -86,15 +86,33 @@ def __init__(self,
assert variance > 0, 'Gamma distribution requires variance > 0'
# shape (concentration) is mean^2/variance, rate is mean/variance for Gamma distribution.
'''
self.mean_clamp = (np.log(0.1), np.log(100.0))
self.logstd_clamp = (np.log(0.1), np.log(10000.0))
shape = prior_mean ** 2 / prior_variance
assert shape > np.exp(self.mean_clamp[0])**2 and shape < np.exp(self.mean_clamp[1])**2, f'Gamma shape {shape} is out of range, should be in ({np.exp(self.mean_clamp[0])**2}, {np.exp(self.mean_clamp[1])**2})'
rate = prior_mean / prior_variance
assert rate > np.exp(self.logstd_clamp[0]) and rate < np.exp(self.logstd_clamp[1]), f'Gamma rate {rate} is out of range, should be in ({np.exp(self.logstd_clamp[0])}, {np.exp(self.logstd_clamp[1])})'
# prior_mean stores ln(shape) for gamma
prior_mean = np.log(shape)
# prior_variance stores rate for gamma
prior_variance = rate
# prior_mean = np.log(prior_mean)
# prior_variance = prior_variance

elif distribution == 'lognormal':
# mean is exp(mu + sigma^2/2), variance is (exp(sigma^2) - 1) * exp(2*mu + sigma^2)
# prior_mean clamped to (-2, exp(1.5))
self.mean_clamp = (-2.0, np.exp(1.5))
# sigma sq clamp exp(-20), exp(1.5)
# therefore sigma in (exp(-10), exp(0.75))
# therefore log sigma in (-10, 0.75)
self.logstd_clamp = (-10.0, 0.75)
assert prior_mean > self.mean_clamp[0] and prior_mean < self.mean_clamp[1], f'Lognormal distribution requires prior_mean in {self.mean_clamp}, given {prior_mean}'
assert np.sqrt(prior_variance) > np.exp(self.logstd_clamp[0]) and np.sqrt(prior_variance) < np.exp(self.logstd_clamp[1]), f'Lognormal distribution requires prior_variance in {self.logstd_clamp}, given {prior_variance}'
# assert prior_mean > np.exp(-100.0) and prior_mean < np.exp(10.0), f'Lognormal distribution requires shape in (exp(-100), exp(10)), given {prior_mean}'
# assert prior_variance > np.exp(-100.0) and prior_variance < np.exp(2.0), f'Lognormal distribution requires rate in (exp(-100), exp(2)), given {prior_variance}'


self.distribution = distribution

self.obs2prior = obs2prior
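As a quick reference, a minimal sketch (not part of this commit) of the moment-matching behind the gamma and lognormal prior re-parameterization above; all numbers are hypothetical.

# Illustrative sketch (not part of the commit): moment-matching for the priors above.
import numpy as np
from torch.distributions import Gamma, LogNormal

prior_mean, prior_variance = 2.0, 0.5            # hypothetical prior moments

# Gamma: mean = shape / rate, variance = shape / rate**2
shape = prior_mean ** 2 / prior_variance         # 8.0
rate = prior_mean / prior_variance               # 4.0
print(Gamma(concentration=shape, rate=rate).mean,      # tensor(2.)
      Gamma(concentration=shape, rate=rate).variance)  # tensor(0.5000)
# BayesianCoefficient then stores log(shape) in prior_mean and the rate in prior_variance.
stored_prior_mean, stored_prior_variance = np.log(shape), rate

# Lognormal: mean = exp(mu + sigma**2/2), variance = (exp(sigma**2) - 1) * exp(2*mu + sigma**2)
mu, sigma = 0.5, 0.25                            # hypothetical location / scale
print(LogNormal(loc=mu, scale=sigma).mean, LogNormal(loc=mu, scale=sigma).variance)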
@@ -144,23 +162,49 @@ def __init__(self,
num_classes, dim) * self.prior_variance)

# create variational distribution.
if self.distribution == 'gaussian' or self.distribution == 'lognormal':
if self.distribution == 'gaussian':
if self.is_H:
self.variational_mean_flexible = nn.Parameter(
torch.randn(num_classes, dim), requires_grad=True)
# multiply by 0.0001 to avoid numerical issues.
self.variational_mean_flexible.data *= 0.0001

else:
self.variational_mean_flexible = nn.Parameter(
torch.randn(num_classes, dim), requires_grad=True)
elif self.distribution == 'lognormal':
self.variational_mean_flexible = nn.Parameter(
torch.randn(num_classes, dim), requires_grad=True)
self.variational_mean_flexible.data = torch.clamp(
self.variational_mean_flexible.data, min=self.mean_clamp[0], max=self.mean_clamp[1])
# TODO(kanodiaayush): initialize the gamma distribution variational mean in a more principled way.
elif self.distribution == 'gamma':
# initialize using uniform distribution between 0.5 and 1.5
# for a gamma distribution, we store the concentration (shape) as log(concentration) = variational_mean_flexible
self.variational_mean_flexible = nn.Parameter(
torch.rand(num_classes, dim) + 0.5, requires_grad=True)
self.variational_mean_flexible.data = torch.clamp(
self.variational_mean_flexible.data, min=self.mean_clamp[0], max=self.mean_clamp[1])

if self.is_H and self.H_zero_mask is not None:
assert self.H_zero_mask.shape == self.variational_mean_flexible.shape, \
f"The H_zero_mask should have exactly the shape as the H variable, `H_zero_mask`.shape is {self.H_zero_mask.shape}, `H`.shape is {self.variational_mean_flexible.shape} "

# for gamma distribution, we store the rate as log(rate) = variational_logstd
self.variational_logstd = nn.Parameter(
torch.randn(num_classes, dim), requires_grad=True)
if self.distribution == 'gaussian':
self.variational_logstd = nn.Parameter(
torch.randn(num_classes, dim), requires_grad=True)
elif self.distribution == 'lognormal':
# uniform -1 to 1
self.variational_logstd = nn.Parameter(
torch.rand(num_classes, dim) * 2 - 1, requires_grad=True)
self.variational_logstd.data = torch.clamp(
self.variational_logstd.data, min=self.logstd_clamp[0], max=self.logstd_clamp[1])
elif self.distribution == 'gamma':
self.variational_logstd = nn.Parameter(
torch.randn(num_classes, dim), requires_grad=True)
self.variational_logstd.data = torch.clamp(
self.variational_logstd.data, min=self.logstd_clamp[0], max=self.logstd_clamp[1])

self.register_buffer('variational_cov_factor',
torch.zeros(num_classes, dim, 1))
@@ -269,10 +313,13 @@ def log_prior(self,
cov_factor=self.prior_cov_factor,
cov_diag=self.prior_cov_diag).log_prob(sample)
elif self.distribution == 'lognormal':
# out = LogNormal(loc=mu, scale=np.sqrt(self.prior_variance)).log_prob(sample)
# out = torch.sum(out, dim=-1)
out = torch.zeros((num_seeds, num_classes), device=sample.device)
mu = torch.clamp(mu, min=-100.0, max=10.0)
out = LogNormal(loc=mu, scale=np.sqrt(self.prior_variance)).log_prob(sample)
out = torch.sum(out, dim=-1)
# out = torch.zeros((num_seeds, num_classes), device=sample.device)

elif self.distribution == 'gamma':
mu = torch.clamp(mu, min=-100.0, max=4.0)
concentration = torch.exp(mu)
rate = self.prior_variance
'''
@@ -286,7 +333,7 @@
# sum over the last dimension
out = torch.sum(out, dim=-1)

out = torch.zeros((num_seeds, num_classes), device=sample.device)
# out = torch.zeros((num_seeds, num_classes), device=sample.device)
assert out.shape == (num_seeds, num_classes)
return out
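A minimal sketch (not part of this commit) of the shape bookkeeping in log_prior above: element-wise log-densities are summed over the last (coefficient) dimension, giving one value per (seed, class). Sizes below are hypothetical.

# Illustrative sketch (not part of the commit): per-class log-prior shapes.
import torch
from torch.distributions import LogNormal

num_seeds, num_classes, dim = 2, 5, 3                    # hypothetical sizes
mu = torch.zeros(num_seeds, num_classes, dim)            # prior location (already clamped)
prior_variance = 1.0
sample = torch.rand(num_seeds, num_classes, dim) + 0.1   # strictly positive draws

out = LogNormal(loc=mu, scale=prior_variance ** 0.5).log_prob(sample)
out = torch.sum(out, dim=-1)                             # sum over the coefficient dimension
assert out.shape == (num_seeds, num_classes)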

@@ -323,9 +370,10 @@ def rsample(self, num_seeds: int = 1) -> Union[torch.Tensor, Tuple[torch.Tensor]
"""
value_sample = self.variational_distribution.rsample(
torch.Size([num_seeds]))
if self.distribution == 'lognormal':
print(torch.min(value_sample))
print(torch.max(value_sample))
# if self.distribution == 'lognormal':
# print(torch.min(value_sample))
# print(torch.max(value_sample))
# breakpoint()
# DEBUG_MARKER
if self.obs2prior:
# sample obs2prior H as well.
@@ -347,20 +395,31 @@ def variational_distribution(self) -> Union[LowRankMultivariateNormal, Gamma]:
elif self.distribution == 'lognormal':
# print(self.variational_mean_flexible)
# print(self.variational_logstd)
print(torch.max(self.variational_logstd), torch.min(self.variational_logstd))
print(torch.max(self.variational_mean_flexible), torch.min(self.variational_mean_flexible))
print(self.variational_mean_flexible.shape, self.variational_logstd.shape)
# print(torch.max(self.variational_logstd), torch.min(self.variational_logstd))
# print(torch.max(self.variational_mean_flexible), torch.min(self.variational_mean_flexible))
# print(self.variational_mean_flexible.shape, self.variational_logstd.shape)
# return LowRankMultivariateNormal(loc=self.variational_mean_flexible,
# cov_factor=self.variational_cov_factor,
# cov_diag=torch.exp(self.variational_logstd))
return LogNormal(loc=self.variational_mean_flexible, scale=torch.exp(self.variational_logstd))
# variational_mean_flexible = torch.clamp(self.variational_mean_flexible, min=-10, max=10)
# variational_logstd = torch.clamp(self.variational_logstd, min=-4, max=3)
loc = self.variational_mean_flexible
scale = torch.exp(self.variational_logstd)
return LogNormal(loc=loc, scale=scale)
elif self.distribution == 'gamma':
# for a gamma distribution, we store the concentration as log(concentration) = variational_mean_flexible
assert self.variational_mean_fixed == None, 'Gamma distribution does not support fixed mean'
concentration = self.variational_mean_flexible.exp() + 0.000001
concentration = torch.minimum(concentration, torch.tensor(1e3))
# assert self.variational_mean_fixed == None, 'Gamma distribution does not support fixed mean'
concentration = torch.exp(self.variational_mean_flexible)
# assert that all concentration should be between exp -4 and exp 4
# assert torch.all(concentration > 0.1353 - 0.0001), 'concentration should be greater than exp -2'
# assert torch.all(concentration < 54.5981 + 0.0001), 'concentration should be less than exp 4'
# concentration = self.variational_mean_flexible.exp() + 0.000001
# concentration = torch.clamp(concentration, min=1e-2, max=1e2)
# concentration = torch.minimum(concentration, torch.tensor(1e3))
# for gamma distribution, we store the rate as log(rate) = variational_logstd
rate = torch.exp(self.variational_logstd)
# print(concentration, rate)
# rate = torch.clamp(rate, min=1e-2, max=1e2)
return Gamma(concentration=concentration, rate=rate)
else:
raise NotImplementedError("Unknown variational distribution type.")
@@ -369,3 +428,26 @@ def variational_distribution(self) -> Union[LowRankMultivariateNormal, Gamma]:
def device(self) -> torch.device:
"""Returns the device of tensors contained in this module."""
return self.variational_mean.device

def clamp_params(self) -> None:
"""Clamps the parameters of the variational distribution to be within a reasonable range.
"""
if self.distribution == 'gaussian':
# do nothing
pass
# self.variational_mean_flexible.data = torch.clamp(
# self.variational_mean_flexible.data, min=-10, max=10)
# self.variational_logstd.data = torch.clamp(
# self.variational_logstd.data, min=-4, max=3)
elif self.distribution in ['lognormal', 'gamma']:
self.variational_mean_flexible.data = torch.clamp(
self.variational_mean_flexible.data, min=self.mean_clamp[0], max=self.mean_clamp[1])
self.variational_logstd.data = torch.clamp(
self.variational_logstd.data, min=self.logstd_clamp[0], max=self.logstd_clamp[1])
# elif self.distribution == 'gamma':
# self.variational_mean_flexible.data = torch.clamp(
# self.variational_mean_flexible.data, min=-2.0, max=4)
# self.variational_logstd.data = torch.clamp(
# self.variational_logstd.data, min=-100.0, max=2.0)
else:
raise NotImplementedError("Unknown variational distribution type.")
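A minimal sketch (not part of this commit) of how clamp_params interacts with the log-space parameterization above, assuming the gamma bounds set in __init__; parameter values are hypothetical.

# Illustrative sketch (not part of the commit): in-place clamping of log-space
# variational parameters, then building the Gamma variational distribution.
import numpy as np
import torch
import torch.nn as nn

mean_clamp = (np.log(0.1), np.log(100.0))        # log(concentration) bounds
logstd_clamp = (np.log(0.1), np.log(10000.0))    # log(rate) bounds

log_concentration = nn.Parameter(torch.randn(5, 3) * 5.0)   # deliberately out of range
log_rate = nn.Parameter(torch.randn(5, 3) * 5.0)

with torch.no_grad():
    log_concentration.data.clamp_(min=mean_clamp[0], max=mean_clamp[1])
    log_rate.data.clamp_(min=logstd_clamp[0], max=logstd_clamp[1])

# The Gamma variational distribution uses the exponentiated parameters, so the clamp
# bounds translate to concentration in (0.1, 100) and rate in (0.1, 10000).
q = torch.distributions.Gamma(concentration=log_concentration.exp(), rate=log_rate.exp())
samples = q.rsample(torch.Size([4]))             # (4, 5, 3), all strictly positive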
7 changes: 6 additions & 1 deletion bemb/model/bemb.py
@@ -422,6 +422,10 @@ def __init__(self,
'Additional modules are temporarily disabled for further development.')
self.additional_modules = nn.ModuleList(additional_modules)

def clamp_coefs(self):
for coef_name in self.coef_dict.keys():
self.coef_dict[coef_name].clamp_params()

def __str__(self):
return f'Bayesian EMBedding Model with U[user, item, session] = {self.raw_formula}\n' \
+ f'Total number of parameters: {self.num_params}.\n' \
@@ -697,13 +701,14 @@ def sample_coefficient_dictionary(self, num_seeds: int, deterministic: bool = Fa
for coef_name, coef in self.coef_dict.items():
if deterministic:
s = coef.variational_distribution.mean.unsqueeze(dim=0) # (1, num_*, dim)
# print(torch.min(s), torch.max(s))
# breakpoint()
# if coef.distribution == 'lognormal':
# s = torch.exp(s)
sample_dict[coef_name] = s
if coef.obs2prior:
sample_dict[coef_name + '.H'] = coef.prior_H.variational_distribution.mean.unsqueeze(dim=0) # (1, num_*, dim)
else:
print(coef_name)
s = coef.rsample(num_seeds)
if coef.obs2prior:
# sample both obs2prior weight and realization of variable.
11 changes: 10 additions & 1 deletion bemb/model/bemb_chunked.py
@@ -358,7 +358,7 @@ def __init__(self,
bayesian_coefs_inner = []
for jj in range(chunk_sizes[1]):
if self.coef_dist_dict[coef_name] == 'gamma' and not self.obs2prior_dict[coef_name]:
assert mean > 0, 'shape of gamma distribution specifieid as prior_mean needs to be > 0'
assert mean > 0, 'shape of gamma distribution specified as prior_mean needs to be > 0'
bayesian_coefs_inner.append(BayesianCoefficient(variation=variation,
num_classes=variation_to_num_classes[variation],
obs2prior=self.obs2prior_dict[coef_name],
@@ -386,6 +386,13 @@ def __init__(self,
'Additional modules are temporarily disabled for further development.')
self.additional_modules = nn.ModuleList(additional_modules)


def clamp_coefs(self):
for coef_name in self.coef_dict.keys():
for ii in range(len(self.coef_dict[coef_name])):
for jj in range(len(self.coef_dict[coef_name][ii])):
self.coef_dict[coef_name][ii][jj].clamp_params()

def __str__(self):
return f'Bayesian EMBedding Model with U[user, item, session] = {self.raw_formula}\n' \
+ f'Total number of parameters: {self.num_params}.\n' \
@@ -1081,6 +1088,7 @@ def reshape_observable(obs, name):
if return_price_coeff and term['observable'] is not None and term['observable'].startswith('price_'):
obs_coeff = coef_sample.sum(dim=-1)
price_coeffs = obs_coeff
price_coeffs *= term['sign']

additive_term = (coef_sample * obs).sum(dim=-1)
additive_term *= term['sign']
@@ -1115,6 +1123,7 @@ def reshape_observable(obs, name):
if return_price_coeff and term['observable'] is not None and term['observable'].startswith('price_'):
obs_coeff = coef.sum(dim=-1)
price_coeffs = obs_coeff
price_coeffs *= term['sign']

additive_term = (coef * obs).sum(dim=-1)
additive_term *= term['sign']
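A minimal sketch (not part of this commit) of the sign handling added to the returned price coefficient above: it is multiplied by the term's sign so it matches the signed additive utility term. Names and shapes are hypothetical.

# Illustrative sketch (not part of the commit): signing the extracted price coefficient.
import torch

term = {'sign': -1.0, 'observable': 'price_obs'}     # hypothetical utility term
coef_sample = torch.randn(2, 5, 3)                   # (num_seeds, num_classes, dim)
obs = torch.rand(2, 5, 3)

price_coeffs = coef_sample.sum(dim=-1)
price_coeffs = price_coeffs * term['sign']           # now signed like the additive term

additive_term = (coef_sample * obs).sum(dim=-1)
additive_term = additive_term * term['sign']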
6 changes: 6 additions & 0 deletions bemb/model/bemb_supermarket_lightning.py
@@ -150,6 +150,12 @@ def test_dataloader(self):
num_workers=self.num_workers)
return test_dataloader

# def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure, ):
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
optimizer.step(closure=optimizer_closure)
with torch.no_grad():
self.model.clamp_coefs()


def write_bemb_cpp_format(self):
model = self.model
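A minimal sketch (not part of this commit) of the step-then-clamp pattern used in optimizer_step above, written as a plain training step rather than the Lightning hook; model, loss_fn, and batch are hypothetical stand-ins.

# Illustrative sketch (not part of the commit): run the optimizer step first,
# then clamp the variational parameters outside the autograd graph.
import torch

def training_step(model, optimizer, loss_fn, batch):
    optimizer.zero_grad()
    loss = loss_fn(model, batch)
    loss.backward()
    optimizer.step()
    with torch.no_grad():
        model.clamp_coefs()      # same call the Lightning hook makes after each step
    return loss.detach()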
