diff --git a/pixyz/distributions/distributions.py b/pixyz/distributions/distributions.py
index a93aa0c..3780fba 100644
--- a/pixyz/distributions/distributions.py
+++ b/pixyz/distributions/distributions.py
@@ -69,14 +69,20 @@ def _reversed_name_dict(self):
     def __apply_dict(dict, var):
         return [dict[var_name] if var_name in dict else var_name for var_name in var]
 
-    def sample(self, values, sample_option):
-        global_input_var = self.__apply_dict(self._reversed_name_dict, self.dist.input_var)
+    def _get_local_input_dict(self, values, input_var=None):
+        if input_var is None:
+            input_var = self.dist.input_var
+        global_input_var = self.__apply_dict(self._reversed_name_dict, input_var)
 
         if any(var_name not in values for var_name in global_input_var):
-            raise ValueError("lack of some condition variables")
+            raise ValueError("lack of some variables")
 
         input_dict = get_dict_values(values, global_input_var, return_dict=True)
         local_input_dict = replace_dict_keys(input_dict, self.name_dict)
+        return local_input_dict
+
+    def sample(self, values, sample_option):
+        local_input_dict = self._get_local_input_dict(values)
 
         # Overwrite log_prob_option with self.option to give priority to local settings such as batch_n
         option = dict(sample_option)
@@ -94,12 +100,7 @@ def sample(self, values, sample_option):
         return sample
 
     def get_log_prob(self, values, log_prob_option):
-        global_input_var = self.__apply_dict(self._reversed_name_dict, list(self.dist.var) + list(self.dist.cond_var))
-
-        if any(var_name not in values for var_name in global_input_var):
-            raise ValueError("lack of some variables")
-        input_dict = get_dict_values(values, global_input_var, return_dict=True)
-        local_input_dict = replace_dict_keys(input_dict, self.name_dict)
+        local_input_dict = self._get_local_input_dict(values, list(self.dist.var) + list(self.dist.cond_var))
 
         # Overwrite log_prob_option with self.option to give priority to local settings such as batch_n
         option = dict(log_prob_option)
@@ -107,6 +108,26 @@ def get_log_prob(self, values, log_prob_option):
         log_prob = self.dist.get_log_prob(local_input_dict, **option)
         return log_prob
 
+    def get_params(self, params_dict={}, **kwargs):
+        orig_params_dict = self._get_local_input_dict(params_dict)
+        params = self.dist.get_params(orig_params_dict, **kwargs)
+        return params
+
+    def sample_mean(self, values={}):
+        local_input_dict = self._get_local_input_dict(values)
+        result = self.dist.sample_mean(local_input_dict)
+        return result
+
+    def sample_variance(self, values={}):
+        local_input_dict = self._get_local_input_dict(values)
+        result = self.dist.sample_variance(local_input_dict)
+        return result
+
+    def get_entropy(self, values={}, sum_features=True, feature_dims=None):
+        local_input_dict = self._get_local_input_dict(values)
+        result = self.dist.get_entropy(local_input_dict, sum_features, feature_dims)
+        return result
+
     @property
     def input_var(self):
         return self.__apply_dict(self._reversed_name_dict, self.dist.input_var)
@@ -673,6 +694,34 @@ def _get_log_prob(self, x_dict, sum_features=True, feature_dims=None, **kwargs):
             return 0
         return log_prob
 
+    def get_params(self, params_dict={}, **kwargs):
+        if len(self.var) != 1:
+            raise NotImplementedError()
+        for factor in self.factors():
+            result = factor.get_params(params_dict, **kwargs)
+        return result
+
+    def sample_mean(self, x_dict={}):
+        if len(self.var) != 1:
+            raise NotImplementedError()
+        for factor in self.factors():
+            result = factor.sample_mean(x_dict)
+        return result
+
+    def sample_variance(self, x_dict={}):
+        if len(self.var) != 1:
+            raise NotImplementedError()
+        for factor in self.factors():
+            result = factor.sample_variance(x_dict)
+        return result
+
+    def get_entropy(self, x_dict={}, sum_features=True, feature_dims=None):
+        if len(self.var) != 1:
+            raise NotImplementedError()
+        for factor in self.factors():
+            result = factor.get_entropy(x_dict, sum_features, feature_dims)
+        return result
+
     @property
     def has_reparam(self):
         return all(factor.dist.has_reparam for factor in self.factors())
@@ -1043,6 +1092,8 @@ def sample_mean(self, x_dict={}):
             1.2810, -0.6681]])
 
         """
+        if self.graph:
+            return self.graph.sample_mean(x_dict)
         raise NotImplementedError()
 
     def sample_variance(self, x_dict={}):
@@ -1073,6 +1124,8 @@ def sample_variance(self, x_dict={}):
         tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
 
         """
+        if self.graph:
+            return self.graph.sample_variance(x_dict)
         raise NotImplementedError()
 
     def get_log_prob(self, x_dict, sum_features=True, feature_dims=None, **kwargs):
@@ -1154,6 +1207,13 @@ def get_entropy(self, x_dict={}, sum_features=True, feature_dims=None):
         tensor([14.1894])
 
         """
+        if self.graph:
+            return self.graph.get_entropy(x_dict, sum_features, feature_dims)
+        raise NotImplementedError()
+
+    def get_params(self, params_dict={}, **kwargs):
+        if self.graph:
+            return self.graph.get_params(params_dict, **kwargs)
         raise NotImplementedError()
 
     def log_prob(self, sum_features=True, feature_dims=None):
@@ -1700,15 +1760,6 @@ def __repr__(self):
     def forward(self, *args, **kwargs):
         return self.p(*args, **kwargs)
 
-    def sample_mean(self, x_dict={}):
-        return self.p.sample_mean(x_dict)
-
-    def sample_variance(self, x_dict={}):
-        return self.p.sample_variance(x_dict)
-
-    def get_entropy(self, x_dict={}, sum_features=True, feature_dims=None):
-        return self.p.get_entropy(x_dict, sum_features, feature_dims)
-
     @property
     def distribution_name(self):
         return self.p.distribution_name
diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py
index fdb970b..5139746 100644
--- a/tests/distributions/test_distribution.py
+++ b/tests/distributions/test_distribution.py
@@ -118,6 +118,67 @@ def test_unknown_option(self, dist):
             dist.get_log_prob(x_dict, unknown_opt=None)
 
 
+class TestReplaceVarDistribution:
+    def test_get_params(self):
+        dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1)
+        result = dist.get_params({'y': torch.ones(1)})
+        assert list(result.keys()) == ['loc', 'scale']
+
+        dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1).replace_var(y='z')
+        result = dist.get_params({'z': torch.ones(1)})
+        assert list(result.keys()) == ['loc', 'scale']
+
+        dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1).replace_var(y='z')
+        with pytest.raises(ValueError):
+            dist.get_params({'y': torch.ones(1)})
+
+        dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1).replace_var(x='z')
+        result = dist.get_params({'y': torch.ones(1)})
+        assert list(result.keys()) == ['loc', 'scale']
+
+        dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1) * Normal(var=['y'], loc=0, scale=1)
+        with pytest.raises(NotImplementedError):
+            dist.get_params()
+
+    def test_sample_mean(self):
+        dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1)
+        result = dist.sample_mean({'y': torch.ones(1)})
+        assert result == torch.ones(1)
+
+        dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1).replace_var(y='z')
+        result = dist.sample_mean({'z': torch.ones(1)})
+        assert result == torch.ones(1)
+
+        dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1).replace_var(y='z')
+        with pytest.raises(ValueError):
+            dist.sample_mean({'y': torch.ones(1)})
+
+    def test_sample_variance(self):
+        dist = Normal(var=['x'], cond_var=['y'], loc=2, scale='y')
+        result = dist.sample_variance({'y': torch.ones(1)})
+        assert result == torch.ones(1)
+
+        dist = Normal(var=['x'], cond_var=['y'], loc=2, scale='y').replace_var(y='z')
+        result = dist.sample_variance({'z': torch.ones(1)})
+        assert result == torch.ones(1)
+
+        dist = Normal(var=['x'], cond_var=['y'], loc=2, scale='y').replace_var(y='z')
+        with pytest.raises(ValueError):
+            dist.sample_variance({'y': torch.ones(1)})
+
+    def test_get_entropy(self):
+        dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1)
+        truth = dist.get_entropy({'y': torch.ones(1)})
+
+        dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1).replace_var(y='z', x='y')
+        result = dist.get_entropy({'z': torch.ones(1)})
+        assert result == truth
+
+        dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1).replace_var(y='z')
+        with pytest.raises(ValueError):
+            dist.get_entropy({'y': torch.ones(1)})
+
+
 class TestMixtureDistribution:
     def test_sample_mean(self):
         dist = MixtureModel([Normal(loc=0, scale=1), Normal(loc=1, scale=1)], Categorical(probs=torch.tensor([1., 2.])))
@@ -193,4 +254,4 @@ def test_save_dist(tmpdir, no_contiguous_tensor):
 
 
 if __name__ == "__main__":
-    test_save_dist(".", torch.zeros(2, 3))
+    TestReplaceVarDistribution().test_get_entropy()
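
Reviewer note: a minimal usage sketch of the name-translating delegation this patch adds. It assumes only the public pixyz API already exercised by the tests above (Normal, replace_var); the inline result comments are the values the new tests expect, not captured output.

    import torch
    from pixyz.distributions import Normal

    # p(x|y) = N(x; loc=y, scale=1), with the conditioning variable renamed y -> z.
    p = Normal(var=['x'], cond_var=['y'], loc='y', scale=1).replace_var(y='z')

    # ReplaceVarDistribution now maps the global name 'z' back to the wrapped
    # distribution's local name 'y' before delegating each call:
    p.get_params({'z': torch.ones(1)})       # dict with keys ['loc', 'scale']
    p.sample_mean({'z': torch.ones(1)})      # tensor([1.])
    p.sample_variance({'z': torch.ones(1)})  # tensor([1.]) since scale=1
    p.get_entropy({'z': torch.ones(1)})      # entropy of N(1, 1), features summed

    # Passing the old name is now rejected in every delegated method:
    try:
        p.get_params({'y': torch.ones(1)})
    except ValueError as e:
        print(e)  # "lack of some variables"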