
Merge pull request #176 from masa-su/fix/replace_var_graph_dev
Fix/replace var graph dev
masa-su authored Dec 14, 2021
2 parents 322d1d6 + c6a0fc8 commit a9c250e
Showing 2 changed files with 131 additions and 19 deletions.
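
A minimal sketch of the behaviour this change enables, based on the tests added below in tests/distributions/test_distribution.py (the commented values follow the new test assertions): after replace_var, parameter and statistics queries are resolved through the distribution graph using the renamed variables.

import torch
from pixyz.distributions import Normal

# p(x|y) = N(x; loc=y, scale=1), then rename the conditioning variable y -> z
p = Normal(var=['x'], cond_var=['y'], loc='y', scale=1).replace_var(y='z')

# Statistics and parameter queries now resolve through the renamed variable
print(p.sample_mean({'z': torch.ones(1)}))       # tensor([1.]) per the new tests
print(list(p.get_params({'z': torch.ones(1)})))  # ['loc', 'scale']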
87 changes: 69 additions & 18 deletions pixyz/distributions/distributions.py
@@ -69,14 +69,20 @@ def _reversed_name_dict(self):
    def __apply_dict(dict, var):
        return [dict[var_name] if var_name in dict else var_name for var_name in var]

    def sample(self, values, sample_option):
        global_input_var = self.__apply_dict(self._reversed_name_dict, self.dist.input_var)
    def _get_local_input_dict(self, values, input_var=None):
        if not input_var:
            input_var = self.dist.input_var
        global_input_var = self.__apply_dict(self._reversed_name_dict, input_var)

        if any(var_name not in values for var_name in global_input_var):
            raise ValueError("lack of some condition variables")
            raise ValueError("lack of some variables")
        input_dict = get_dict_values(values, global_input_var, return_dict=True)

        local_input_dict = replace_dict_keys(input_dict, self.name_dict)
        return local_input_dict

    def sample(self, values, sample_option):
        local_input_dict = self._get_local_input_dict(values)

        # Overwrite sample_option with self.option to give priority to local settings such as batch_n
        option = dict(sample_option)
@@ -94,19 +94,34 @@ def sample(self, values, sample_option):
        return sample

    def get_log_prob(self, values, log_prob_option):
        global_input_var = self.__apply_dict(self._reversed_name_dict, list(self.dist.var) + list(self.dist.cond_var))

        if any(var_name not in values for var_name in global_input_var):
            raise ValueError("lack of some variables")
        input_dict = get_dict_values(values, global_input_var, return_dict=True)
        local_input_dict = replace_dict_keys(input_dict, self.name_dict)
        local_input_dict = self._get_local_input_dict(values, list(self.dist.var) + list(self.dist.cond_var))

        # Overwrite log_prob_option with self.option to give priority to local settings such as batch_n
        option = dict(log_prob_option)
        option.update(self.option)
        log_prob = self.dist.get_log_prob(local_input_dict, **option)
        return log_prob

    def get_params(self, params_dict={}, **kwargs):
        orig_params_dict = self._get_local_input_dict(params_dict)
        params = self.dist.get_params(orig_params_dict, **kwargs)
        return params

    def sample_mean(self, values={}):
        local_input_dict = self._get_local_input_dict(values)
        result = self.dist.sample_mean(local_input_dict)
        return result

    def sample_variance(self, values={}):
        local_input_dict = self._get_local_input_dict(values)
        result = self.dist.sample_variance(local_input_dict)
        return result

    def get_entropy(self, values={}, sum_features=True, feature_dims=None):
        local_input_dict = self._get_local_input_dict(values)
        result = self.dist.get_entropy(local_input_dict, sum_features, feature_dims)
        return result

    @property
    def input_var(self):
        return self.__apply_dict(self._reversed_name_dict, self.dist.input_var)
@@ -673,6 +694,34 @@ def _get_log_prob(self, x_dict, sum_features=True, feature_dims=None, **kwargs):
            return 0
        return log_prob

    def get_params(self, params_dict={}, **kwargs):
        if len(self.var) != 1:
            raise NotImplementedError()
        for factor in self.factors():
            result = factor.get_params(params_dict, **kwargs)
        return result

    def sample_mean(self, x_dict={}):
        if len(self.var) != 1:
            raise NotImplementedError()
        for factor in self.factors():
            result = factor.sample_mean(x_dict)
        return result

    def sample_variance(self, x_dict={}):
        if len(self.var) != 1:
            raise NotImplementedError()
        for factor in self.factors():
            result = factor.sample_variance(x_dict)
        return result

    def get_entropy(self, x_dict={}, sum_features=True, feature_dims=None):
        if len(self.var) != 1:
            raise NotImplementedError()
        for factor in self.factors():
            result = factor.get_entropy(x_dict, sum_features, feature_dims)
        return result

    @property
    def has_reparam(self):
        return all(factor.dist.has_reparam for factor in self.factors())
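
The graph-level helpers above are only defined when the graph exposes a single variable. A minimal sketch of that guard, mirroring the product-distribution case in the new tests:

from pixyz.distributions import Normal

# p(x, y) = p(x|y) p(y): two variables, so a single set of parameters is ambiguous
p = Normal(var=['x'], cond_var=['y'], loc='y', scale=1) * Normal(var=['y'], loc=0, scale=1)

try:
    p.get_params()
except NotImplementedError:
    print("get_params is only defined for single-variable graphs")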
@@ -1043,6 +1092,8 @@ def sample_mean(self, x_dict={}):
                 1.2810, -0.6681]])
        """
        if self.graph:
            return self.graph.sample_mean(x_dict)
        raise NotImplementedError()

    def sample_variance(self, x_dict={}):
@@ -1073,6 +1124,8 @@ def sample_variance(self, x_dict={}):
        tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
        """
        if self.graph:
            return self.graph.sample_variance(x_dict)
        raise NotImplementedError()

    def get_log_prob(self, x_dict, sum_features=True, feature_dims=None, **kwargs):
@@ -1154,6 +1207,13 @@ def get_entropy(self, x_dict={}, sum_features=True, feature_dims=None):
        tensor([14.1894])
        """
        if self.graph:
            return self.graph.get_entropy(x_dict, sum_features, feature_dims)
        raise NotImplementedError()

    def get_params(self, params_dict={}, **kwargs):
        if self.graph:
            return self.graph.get_params(params_dict, **kwargs)
        raise NotImplementedError()

    def log_prob(self, sum_features=True, feature_dims=None):
@@ -1700,15 +1760,6 @@ def __repr__(self):
    def forward(self, *args, **kwargs):
        return self.p(*args, **kwargs)

    def sample_mean(self, x_dict={}):
        return self.p.sample_mean(x_dict)

    def sample_variance(self, x_dict={}):
        return self.p.sample_variance(x_dict)

    def get_entropy(self, x_dict={}, sum_features=True, feature_dims=None):
        return self.p.get_entropy(x_dict, sum_features, feature_dims)

    @property
    def distribution_name(self):
        return self.p.distribution_name
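With the self.graph forwarding added above, the entropy (and the other statistics) of a distribution are unchanged by variable renaming. A small sketch along the lines of test_get_entropy in the test file below:

import torch
from pixyz.distributions import Normal

p = Normal(var=['x'], cond_var=['y'], loc='y', scale=1)
q = Normal(var=['x'], cond_var=['y'], loc='y', scale=1).replace_var(y='z', x='y')

# Renaming the variables does not change the entropy of the underlying Normal
assert q.get_entropy({'z': torch.ones(1)}) == p.get_entropy({'y': torch.ones(1)})
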
63 changes: 62 additions & 1 deletion tests/distributions/test_distribution.py
@@ -118,6 +118,67 @@ def test_unknown_option(self, dist):
            dist.get_log_prob(x_dict, unknown_opt=None)


class TestReplaceVarDistribution:
    def test_get_params(self):
        dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1)
        result = dist.get_params({'y': torch.ones(1)})
        assert list(result.keys()) == ['loc', 'scale']

        dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1).replace_var(y='z')
        result = dist.get_params({'z': torch.ones(1)})
        assert list(result.keys()) == ['loc', 'scale']

        dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1).replace_var(y='z')
        with pytest.raises(ValueError):
            dist.get_params({'y': torch.ones(1)})

        dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1).replace_var(x='z')
        result = dist.get_params({'y': torch.ones(1)})
        assert list(result.keys()) == ['loc', 'scale']

        dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1) * Normal(var=['y'], loc=0, scale=1)
        with pytest.raises(NotImplementedError):
            dist.get_params()

    def test_sample_mean(self):
        dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1)
        result = dist.sample_mean({'y': torch.ones(1)})
        assert result == torch.ones(1)

        dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1).replace_var(y='z')
        result = dist.sample_mean({'z': torch.ones(1)})
        assert result == torch.ones(1)

        dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1).replace_var(y='z')
        with pytest.raises(ValueError):
            dist.sample_mean({'y': torch.ones(1)})

    def test_sample_variance(self):
        dist = Normal(var=['x'], cond_var=['y'], loc=2, scale='y')
        result = dist.sample_variance({'y': torch.ones(1)})
        assert result == torch.ones(1)

        dist = Normal(var=['x'], cond_var=['y'], loc=2, scale='y').replace_var(y='z')
        result = dist.sample_variance({'z': torch.ones(1)})
        assert result == torch.ones(1)

        dist = Normal(var=['x'], cond_var=['y'], loc=2, scale='y').replace_var(y='z')
        with pytest.raises(ValueError):
            dist.sample_variance({'y': torch.ones(1)})

    def test_get_entropy(self):
        dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1)
        truth = dist.get_entropy({'y': torch.ones(1)})

        dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1).replace_var(y='z', x='y')
        result = dist.get_entropy({'z': torch.ones(1)})
        assert result == truth

        dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1).replace_var(y='z')
        with pytest.raises(ValueError):
            dist.get_entropy({'y': torch.ones(1)})


class TestMixtureDistribution:
    def test_sample_mean(self):
        dist = MixtureModel([Normal(loc=0, scale=1), Normal(loc=1, scale=1)], Categorical(probs=torch.tensor([1., 2.])))
@@ -193,4 +254,4 @@ def test_save_dist(tmpdir, no_contiguous_tensor):


if __name__ == "__main__":
    test_save_dist(".", torch.zeros(2, 3))
    TestReplaceVarDistribution().test_get_entropy()
