diff --git a/source/hsicbt/math/hsic.py b/source/hsicbt/math/hsic.py
index 214880e..932e1f3 100644
--- a/source/hsicbt/math/hsic.py
+++ b/source/hsicbt/math/hsic.py
@@ -1,7 +1,8 @@
 import torch
 import numpy as np
 from torch.autograd import Variable, grad
-
+# pylint: disable=no-member
+# pylint: disable=not-callable
 
 def sigma_estimation(X, Y):
     """ sigma from median distance """
@@ -17,42 +18,45 @@
     return med
 
 def distmat(X):
-    """ distance matrix
+    """ distance matrix |X.X - 2(X x Xt) + (X.X)t|
+    Args
+        X (tensor) shape (batchsize, dims)
     """
-    r = torch.sum(X*X, 1)
-    r = r.view([-1, 1])
-    a = torch.mm(X, torch.transpose(X,0,1))
-    D = r.expand_as(a) - 2*a + torch.transpose(r,0,1).expand_as(a)
-    D = torch.abs(D)
-    return D
-
-def kernelmat(X, sigma):
+    out = torch.mm(X, X.T).mul_(-2.0)
+    out.add_((X*X).sum(1, keepdim=True))
+    out.add_((X*X).sum(1, keepdim=True).T)
+    return out.abs_()
+
+def kernelmat(X, sigma=None):
     """ kernel matrix baker
+    Args
+        X (tensor) shape (batchsize, dims)
+        sigma (float [None]) from config
     """
-    m = int(X.size()[0])
-    dim = int(X.size()[1]) * 1.0
-    H = torch.eye(m) - (1./m) * torch.ones([m,m])
-    Dxx = distmat(X)
-
+    m, dim = X.size()
+    H = torch.eye(m, device=X.device).sub_(1/m)
+    Kx = distmat(X)
+
     if sigma:
-        variance = 2.*sigma*sigma*X.size()[1]
-        Kx = torch.exp( -Dxx / variance).type(torch.FloatTensor)   # kernel matrices
-        # print(sigma, torch.mean(Kx), torch.max(Kx), torch.min(Kx))
+        variance = 2.*sigma*sigma*dim
+        torch.exp_(Kx.mul_(-1.0/variance))
     else:
         try:
-            sx = sigma_estimation(X,X)
-            Kx = torch.exp( -Dxx / (2.*sx*sx)).type(torch.FloatTensor)
+            sx = sigma_estimation(X, X)
+            variance = 2.*sx*sx
+            torch.exp_(Kx.mul_(-1.0/variance))
         except RuntimeError as e:
             raise RuntimeError("Unstable sigma {} with maximum/minimum input ({},{})".format(
                 sx, torch.max(X), torch.min(X)))
-    Kxc = torch.mm(Kx,H)
-
+    Kxc = torch.mm(Kx, H)
+    del H
+    del Kx
     return Kxc
 
 def distcorr(X, sigma=1.0):
     X = distmat(X)
-    X = torch.exp( -X / (2.*sigma*sigma))
+    X = torch.exp(-X / (2.*sigma*sigma))
     return torch.mean(X)
 
 def compute_kernel(x, y):
@@ -137,21 +141,28 @@
     thehsic = Pxy/(Px*Py)
     return thehsic
 
-def hsic_normalized_cca(x, y, sigma, use_cuda=True, to_numpy=True):
+def hsic_normalized_cca(x, y, sigma=None):
     """
+    Args
+        x (tensor) shape (batchsize, dims)
+        y (tensor) shape (batchsize, dims)
+        sigma (float [None])
     """
-    m = int(x.size()[0])
-    Kxc = kernelmat(x, sigma=sigma)
-    Kyc = kernelmat(y, sigma=sigma)
-    epsilon = 1E-5
-    K_I = torch.eye(m)
-    Kxc_i = torch.inverse(Kxc + epsilon*m*K_I)
-    Kyc_i = torch.inverse(Kyc + epsilon*m*K_I)
-    Rx = (Kxc.mm(Kxc_i))
-    Ry = (Kyc.mm(Kyc_i))
-    Pxy = torch.sum(torch.mul(Rx, Ry.t()))
+    m = x.size()[0]
+    epsilon = 1E-5
+    K_I = torch.eye(m, device=x.device).mul_(epsilon*m)
 
-    return Pxy
+    Kc = kernelmat(x, sigma=sigma)
+    Rx = Kc.mm(Kc.add(K_I).inverse())
+
+    Kc = kernelmat(y, sigma=sigma)
+    Ry = Kc.mm(Kc.add(K_I).inverse())
+
+    out = Rx.mul_(Ry.t()).sum()
+    del Rx
+    del Ry
+    del Kc
+    del K_I
+    return out
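
For reference, below is a minimal, self-contained sanity check (not part of the patch) showing that the refactored in-place distmat and the eye(m).sub_(1/m) centering used in kernelmat reproduce the quantities computed by the removed code. The names distmat_old and distmat_new are local to this sketch, which assumes only a reasonably recent PyTorch; the batch size and dimensions are arbitrary.

import torch

def distmat_old(X):
    # distance matrix as computed by the removed code
    r = torch.sum(X*X, 1).view([-1, 1])
    a = torch.mm(X, torch.transpose(X, 0, 1))
    D = r.expand_as(a) - 2*a + torch.transpose(r, 0, 1).expand_as(a)
    return torch.abs(D)

def distmat_new(X):
    # distance matrix as computed by the patched, in-place code
    out = torch.mm(X, X.T).mul_(-2.0)
    out.add_((X*X).sum(1, keepdim=True))
    out.add_((X*X).sum(1, keepdim=True).T)
    return out.abs_()

torch.manual_seed(0)
X = torch.randn(32, 16)
assert torch.allclose(distmat_old(X), distmat_new(X), atol=1e-4)

# eye(m).sub_(1/m) subtracts 1/m from every entry, which equals the
# old centering matrix eye(m) - (1/m)*ones(m, m)
m = X.size(0)
H_old = torch.eye(m) - (1./m) * torch.ones([m, m])
H_new = torch.eye(m).sub_(1./m)
assert torch.allclose(H_old, H_new)

For distmat and the centering matrix, the refactor changes memory behaviour (fewer temporaries, in-place ops), not the math; kernelmat additionally drops the .type(torch.FloatTensor) cast, so kernels now stay on X.device instead of being forced to a CPU float tensor.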