
Commit

Merge pull request #175 from masa-su/develop/v0.3.3
Develop/v0.3.3
masa-su authored Dec 14, 2021
2 parents be15a93 + a9c250e commit a9baf06
Showing 11 changed files with 497 additions and 204 deletions.
210 changes: 84 additions & 126 deletions examples/cvae.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pixyz/__init__.py
@@ -1,2 +1,2 @@
name = "pixyz"
__version__ = "0.3.2"
__version__ = "0.3.3"
2 changes: 2 additions & 0 deletions pixyz/distributions/__init__.py
@@ -28,6 +28,7 @@
)

from .poe import ProductOfNormal, ElementWiseProductOfNormal
from .moe import MixtureOfNormal

from .mixture_distributions import MixtureModel

@@ -55,6 +56,7 @@
'MarginalizeVarDistribution',
'ProductOfNormal',
'ElementWiseProductOfNormal',
'MixtureOfNormal',
'MixtureModel',

'TransformedDistribution',
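
With the export added above, the new class is importable from the package namespace next to the product-of-experts variants. A minimal check, not part of this diff, assuming pixyz 0.3.3 is installed:

from pixyz.distributions import MixtureOfNormal, ProductOfNormal
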
87 changes: 69 additions & 18 deletions pixyz/distributions/distributions.py
@@ -69,14 +69,20 @@ def _reversed_name_dict(self):
def __apply_dict(dict, var):
return [dict[var_name] if var_name in dict else var_name for var_name in var]

def sample(self, values, sample_option):
global_input_var = self.__apply_dict(self._reversed_name_dict, self.dist.input_var)
def _get_local_input_dict(self, values, input_var=None):
if not input_var:
input_var = self.dist.input_var
global_input_var = self.__apply_dict(self._reversed_name_dict, input_var)

if any(var_name not in values for var_name in global_input_var):
raise ValueError("lack of some condition variables")
raise ValueError("lack of some variables")
input_dict = get_dict_values(values, global_input_var, return_dict=True)

local_input_dict = replace_dict_keys(input_dict, self.name_dict)
return local_input_dict

def sample(self, values, sample_option):
local_input_dict = self._get_local_input_dict(values)

# Overwrite log_prob_option with self.option to give priority to local settings such as batch_n
option = dict(sample_option)
@@ -94,19 +100,34 @@ def sample(self, values, sample_option):
return sample

def get_log_prob(self, values, log_prob_option):
global_input_var = self.__apply_dict(self._reversed_name_dict, list(self.dist.var) + list(self.dist.cond_var))

if any(var_name not in values for var_name in global_input_var):
raise ValueError("lack of some variables")
input_dict = get_dict_values(values, global_input_var, return_dict=True)
local_input_dict = replace_dict_keys(input_dict, self.name_dict)
local_input_dict = self._get_local_input_dict(values, list(self.dist.var) + list(self.dist.cond_var))

# Overwrite log_prob_option with self.option to give priority to local settings such as batch_n
option = dict(log_prob_option)
option.update(self.option)
log_prob = self.dist.get_log_prob(local_input_dict, **option)
return log_prob

def get_params(self, params_dict={}, **kwargs):
orig_params_dict = self._get_local_input_dict(params_dict)
params = self.dist.get_params(orig_params_dict, **kwargs)
return params

def sample_mean(self, values={}):
local_input_dict = self._get_local_input_dict(values)
result = self.dist.sample_mean(local_input_dict)
return result

def sample_variance(self, values={}):
local_input_dict = self._get_local_input_dict(values)
result = self.dist.sample_variance(local_input_dict)
return result

def get_entropy(self, values={}, sum_features=True, feature_dims=None):
local_input_dict = self._get_local_input_dict(values)
result = self.dist.get_entropy(local_input_dict, sum_features, feature_dims)
return result

@property
def input_var(self):
return self.__apply_dict(self._reversed_name_dict, self.dist.input_var)
@@ -673,6 +694,34 @@ def _get_log_prob(self, x_dict, sum_features=True, feature_dims=None, **kwargs):
return 0
return log_prob
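
# The statistics below (get_params, sample_mean, sample_variance, get_entropy) are
# delegated to the graph's single factor and are only defined when the graph has
# exactly one variable; otherwise NotImplementedError is raised.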

def get_params(self, params_dict={}, **kwargs):
if len(self.var) != 1:
raise NotImplementedError()
for factor in self.factors():
result = factor.get_params(params_dict, **kwargs)
return result

def sample_mean(self, x_dict={}):
if len(self.var) != 1:
raise NotImplementedError()
for factor in self.factors():
result = factor.sample_mean(x_dict)
return result

def sample_variance(self, x_dict={}):
if len(self.var) != 1:
raise NotImplementedError()
for factor in self.factors():
result = factor.sample_variance(x_dict)
return result

def get_entropy(self, x_dict={}, sum_features=True, feature_dims=None):
if len(self.var) != 1:
raise NotImplementedError()
for factor in self.factors():
result = factor.get_entropy(x_dict, sum_features, feature_dims)
return result

@property
def has_reparam(self):
return all(factor.dist.has_reparam for factor in self.factors())
@@ -1043,6 +1092,8 @@ def sample_mean(self, x_dict={}):
1.2810, -0.6681]])
"""
if self.graph:
return self.graph.sample_mean(x_dict)
raise NotImplementedError()

def sample_variance(self, x_dict={}):
@@ -1073,6 +1124,8 @@ def sample_variance(self, x_dict={}):
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
"""
if self.graph:
return self.graph.sample_variance(x_dict)
raise NotImplementedError()

def get_log_prob(self, x_dict, sum_features=True, feature_dims=None, **kwargs):
@@ -1154,6 +1207,13 @@ def get_entropy(self, x_dict={}, sum_features=True, feature_dims=None):
tensor([14.1894])
"""
if self.graph:
return self.graph.get_entropy(x_dict, sum_features, feature_dims)
raise NotImplementedError()

def get_params(self, params_dict={}, **kwargs):
if self.graph:
return self.graph.get_params(params_dict, **kwargs)
raise NotImplementedError()

def log_prob(self, sum_features=True, feature_dims=None):
@@ -1700,15 +1760,6 @@ def __repr__(self):
def forward(self, *args, **kwargs):
return self.p(*args, **kwargs)

def sample_mean(self, x_dict={}):
return self.p.sample_mean(x_dict)

def sample_variance(self, x_dict={}):
return self.p.sample_variance(x_dict)

def get_entropy(self, x_dict={}, sum_features=True, feature_dims=None):
return self.p.get_entropy(x_dict, sum_features, feature_dims)

@property
def distribution_name(self):
return self.p.distribution_name
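
With these changes, get_params, sample_mean, sample_variance and get_entropy dispatch through the distribution graph (via the per-factor delegation above) rather than through the wrapper overrides removed here. A minimal sketch of the affected calls, not part of this diff; the printed values are approximate and follow the docstring examples:

import torch
from pixyz.distributions import Normal

p = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), var=["x"], features_shape=[10])
print(p.sample_mean())      # zeros with shape (1, 10)
print(p.sample_variance())  # ones with shape (1, 10)
print(p.get_entropy())      # ~tensor([14.1894]) for a 10-dimensional standard normal
print(p.get_params())       # {'loc': tensor(0.), 'scale': tensor(1.)}
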
169 changes: 169 additions & 0 deletions pixyz/distributions/moe.py
@@ -0,0 +1,169 @@
from __future__ import print_function
import torch
from torch import nn
import numpy as np

from ..utils import tolist, get_dict_values
from ..distributions import Normal


class MixtureOfNormal(Normal):
r"""Mixture of normal distributions.
.. math::
p(z|x,y) = p(z|x) + p(z|y)
In this models, :math:`p(z|x)` and :math:`p(a|y)` perform as `experts`.
References
----------
[Shi+ 2019] Variational Mixture-of-Experts Autoencoders for Multi-Modal Deep Generative Models
"""

    def __init__(self, p=[], weight_modalities=None, name="p", features_shape=torch.Size()):
        """
        Parameters
        ----------
        p : :obj:`list` of :class:`pixyz.distributions.Normal`.
            List of experts.
        weight_modalities : :obj:`list` of :obj:`float` or :obj:`torch.Tensor`, defaults to None
            Mixture weight of each expert. If None, all experts are weighted equally.
        name : :obj:`str`, defaults to "p"
            Name of this distribution.
            This name is displayed in prob_text and prob_factorized_text.
        features_shape : :obj:`torch.Size` or :obj:`list`, defaults to torch.Size()
            Shape of dimensions (features) of this distribution.
        """

        p = tolist(p)
        if len(p) == 0:
            raise ValueError()

        if weight_modalities is None:
            weight_modalities = torch.ones(len(p)) / float(len(p))

        elif len(weight_modalities) != len(p):
            raise ValueError()

        var = p[0].var
        cond_var = []

        for _p in p:
            if _p.var != var:
                raise ValueError()

            cond_var += _p.cond_var

        cond_var = list(set(cond_var))

        super().__init__(var=var, cond_var=cond_var, name=name, features_shape=features_shape)
        self.p = nn.ModuleList(p)
        self.weight_modalities = weight_modalities

    def _get_expert_params(self, params_dict={}, **kwargs):
        """Get the output parameters of all experts.

        Parameters
        ----------
        params_dict : dict
        **kwargs
            Arbitrary keyword arguments.

        Returns
        -------
        loc : torch.Tensor
            Concatenation of mean vectors for specified experts. (n_expert, n_batch, output_dim)
        scale : torch.Tensor
            Concatenation of the square root of a diagonal covariance matrix for specified experts.
            (n_expert, n_batch, output_dim)
        """

        loc = []
        scale = []

        for i, _p in enumerate(self.p):
            inputs_dict = get_dict_values(params_dict, _p.cond_var, True)
            if len(inputs_dict) != 0:
                outputs = _p.get_params(inputs_dict, **kwargs)
                loc.append(outputs["loc"])
                scale.append(outputs["scale"])

        loc = torch.stack(loc)
        scale = torch.stack(scale)

        return loc, scale

    def get_params(self, params_dict={}, **kwargs):
        # experts
        if len(params_dict) > 0:
            loc, scale = self._get_expert_params(params_dict, **kwargs)  # (n_expert, n_batch, output_dim)
        else:
            raise ValueError()

        output_loc, output_scale = self._compute_expert_params(loc, scale)
        output_dict = {"loc": output_loc, "scale": output_scale}

        return output_dict

    def _compute_expert_params(self, loc, scale):
        """Compute parameters for the mixture of experts.

        It is assumed that unspecified experts are excluded from the inputs.

        Parameters
        ----------
        loc : torch.Tensor
            Concatenation of mean vectors for specified experts. (n_expert, n_batch, output_dim)
        scale : torch.Tensor
            Concatenation of the square root of a diagonal covariance matrix for specified experts.
            (n_expert, n_batch, output_dim)

        Returns
        -------
        output_loc : torch.Tensor
            Mean vectors for this distribution. (n_batch, output_dim)
        output_scale : torch.Tensor
            The square root of diagonal covariance matrices for this distribution. (n_batch, output_dim)
        """
        num_samples = loc.shape[1]
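        # Stratified split of the batch: expert k contributes a contiguous slice of rows whose
        # size is proportional to weight_modalities[k]; e.g. (illustrative) with 10 samples and
        # weights [0.5, 0.5], expert 0 covers rows 0-4 and expert 1 covers rows 5-9.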

        idx_start = []
        idx_end = []
        for k in range(0, len(self.weight_modalities)):
            if k == 0:
                i_start = 0
            else:
                i_start = int(idx_end[k - 1])
            if k == len(self.weight_modalities) - 1:
                i_end = num_samples
            else:
                i_end = i_start + int(np.floor(num_samples * self.weight_modalities[k]))
            idx_start.append(i_start)
            idx_end.append(i_end)

        idx_end[-1] = num_samples

        output_loc = torch.cat([loc[k, idx_start[k]:idx_end[k], :] for k in range(len(self.weight_modalities))])
        output_scale = torch.cat([scale[k, idx_start[k]:idx_end[k], :] for k in range(len(self.weight_modalities))])

        return output_loc, output_scale

    def _get_input_dict(self, x, var=None):
        if var is None:
            var = self.input_var

        if type(x) is torch.Tensor:
            checked_x = {var[0]: x}

        elif type(x) is list:
            # TODO: we need to check if all the elements contained in this list are torch.Tensor.
            checked_x = dict(zip(var, x))

        elif type(x) is dict:
            # point of modification
            checked_x = x

        else:
            raise ValueError("The type of input is not valid, got %s." % type(x))

        return get_dict_values(checked_x, var, return_dict=True)

    def get_log_prob(self, x_dict, sum_features=True, feature_dims=None):
        log_prob = torch.stack([w * p.get_log_prob(x_dict, sum_features=sum_features, feature_dims=feature_dims)
                                for p, w in zip(self.p, self.weight_modalities)])
        log_prob = torch.logsumexp(log_prob, dim=0)

        return log_prob
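
The new class mirrors the expert pattern of ProductOfNormal: each expert is a conditional Normal over a shared latent variable, get_params draws a stratified batch across the experts according to weight_modalities, and get_log_prob combines the experts' weighted log-densities with a logsumexp over the expert dimension. A minimal usage sketch, not part of this commit; the expert networks, dimensions, and variable names below are illustrative only:

import torch
from torch import nn
from torch.nn import functional as F
from pixyz.distributions import Normal, MixtureOfNormal

class ExpertX(Normal):
    # q(z|x): illustrative expert over the shared latent z
    def __init__(self):
        super().__init__(var=["z"], cond_var=["x"], name="q_x")
        self.fc = nn.Linear(16, 8)

    def forward(self, x):
        loc, scale = torch.chunk(self.fc(x), 2, dim=-1)
        return {"loc": loc, "scale": F.softplus(scale)}

class ExpertY(Normal):
    # q(z|y): second illustrative expert over the same latent z
    def __init__(self):
        super().__init__(var=["z"], cond_var=["y"], name="q_y")
        self.fc = nn.Linear(12, 8)

    def forward(self, y):
        loc, scale = torch.chunk(self.fc(y), 2, dim=-1)
        return {"loc": loc, "scale": F.softplus(scale)}

q = MixtureOfNormal([ExpertX(), ExpertY()])   # uniform weights by default
samples = q.sample({"x": torch.randn(10, 16), "y": torch.randn(10, 12)})
log_prob = q.get_log_prob(samples)            # shape: (10,)

Unlike ProductOfNormal, which fuses the experts multiplicatively into a single precision-weighted Gaussian, MixtureOfNormal keeps each expert's parameters and allocates batch rows to the experts in proportion to weight_modalities.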

0 comments on commit a9baf06
