-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
70 lines (60 loc) · 1.76 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# -*- coding: utf-8 -*-
"""
Created on Sat Nov 23 00:16:44 2019
@author: urixs
"""
import numpy as np
from sklearn.neighbors import NearestNeighbors
import torch
from torch.autograd import Variable
def compute_dist_mat(X, Y=None, device=torch.device("cpu")):
    """
    Computes nxm matrix of squared Euclidean distances

    args:
        X: nxd tensor of data points
        Y: mxd tensor of data points (optional; defaults to X)
        device: device both inputs are moved to before the computation;
            the returned matrix lives on this device
    returns:
        nxm tensor whose (i, j) entry is sum((X[i] - Y[j]) ** 2),
        in the dtype of X
    """
    if Y is None:
        Y = X
    X = X.to(device=device)
    Y = Y.to(device=device)
    # Broadcast (n, 1, d) against (1, m, d) so every pairwise difference is
    # formed by one vectorized op rather than a Python loop over the rows
    # of X. NOTE: this materializes an n x m x d intermediate tensor.
    diff = X.unsqueeze(1) - Y.unsqueeze(0)
    sq_dist = diff.pow(2).sum(dim=2)
    # Return the result in X's dtype; the reduction may promote integer
    # inputs, and X and Y may have different dtypes.
    return sq_dist.to(dtype=X.dtype)
def nn_search(X, Y=None, k=10):
    """
    Finds, for every point in X, its k nearest neighbors among the points in Y

    args:
        X: nxd tensor of query points
        Y: mxd tensor of data points (optional; when omitted, X is searched
            against itself)
        k: number of neighbors
    returns:
        (Dis, Ids): the pair of arrays produced by scikit-learn's
        NearestNeighbors.kneighbors for the query points, i.e. the neighbor
        distances and the corresponding neighbor indices
    """
    reference = X if Y is None else Y
    # scikit-learn operates on numpy arrays, so detach from the autograd
    # graph and bring both point sets to host memory first.
    query_np = X.cpu().detach().numpy()
    reference_np = reference.cpu().detach().numpy()
    index = NearestNeighbors(n_neighbors=k, algorithm='ball_tree')
    index.fit(reference_np)
    distances, indices = index.kneighbors(query_np)
    return distances, indices
def compute_scale(Dis, k=5):
    """
    Computes scale as the median distance to the k-th neighbor

    args:
        Dis: nxk' numpy array of distances (output of nn_search)
        k: 1-based rank of the neighbor whose distance column is used
            (1 <= k <= k')
    returns:
        median, over the n rows, of column k - 1 of Dis
    raises:
        ValueError: if k < 1
        IndexError: if k > k' (raised by the numpy indexing)
    """
    if k < 1:
        # k - 1 would be negative, and numpy would silently wrap around to a
        # column counted from the end instead of reporting the bad argument.
        raise ValueError("k must be at least 1, got {}".format(k))
    return np.median(Dis[:, k - 1])
def compute_kernel_mat(D, scale, device=torch.device('cpu')):
    """
    Computes RBF kernel matrix exp(-D / scale^2)

    args:
        D: tensor of squared distances
        scale: kernel scale; its square divides the squared distances
        device: accepted for interface compatibility; not used by this
            function (the result is computed from D as given)
    returns:
        tensor of the same shape as D holding the elementwise kernel values
    """
    bandwidth_sq = scale ** 2
    exponent = -D / bandwidth_sq
    return torch.exp(exponent)