forked from YingfanWang/PaCMAP
-
Notifications
You must be signed in to change notification settings - Fork 0
/
specify_nn_demo.py
51 lines (41 loc) · 1.81 KB
/
specify_nn_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import pacmap
import numpy as np
import matplotlib.pyplot as plt
from annoy import AnnoyIndex
def _annoy_neighbors(data, k, n_trees):
    """Return an (n, k) int32 matrix of approximate nearest neighbors.

    Row ``i`` holds the indices of the ``k`` points closest to ``data[i]``
    under the Euclidean metric, as ordered by Annoy, excluding ``i`` itself.

    Args:
        data: 2-D array, one sample per row.
        k: number of neighbors to keep per point.
        n_trees: number of trees in the Annoy forest (more trees give more
            accurate results at the cost of a slower build).

    Raises:
        RuntimeError: if Annoy returns fewer than ``k`` neighbors for a point.
    """
    n_points, n_dims = data.shape
    index = AnnoyIndex(n_dims, metric='euclidean')
    for i, row in enumerate(data):
        index.add_item(i, row)
    index.build(n_trees)

    neighbors = np.empty((n_points, k), dtype=np.int32)
    for i in range(n_points):
        # Request one extra result because the query point is normally
        # returned as its own nearest neighbor. Annoy is approximate and
        # duplicate points can tie, so remove `i` by value rather than
        # assuming it always sits at position 0.
        candidates = [j for j in index.get_nns_by_item(i, k + 1) if j != i][:k]
        if len(candidates) < k:
            raise RuntimeError(
                f"Annoy returned only {len(candidates)} neighbors for point {i}; "
                f"expected {k}. Try increasing n_trees."
            )
        neighbors[i, :] = candidates
    return neighbors


# loading preprocessed coil_20 dataset
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects, so
# only use data files from a trusted source. The paths are relative to the
# current working directory, so run this script from its own folder.
X = np.load("../data/coil_20.npy", allow_pickle=True)
X = X.reshape(X.shape[0], -1)  # flatten each sample into one feature vector
y = np.load("../data/coil_20_labels.npy", allow_pickle=True)

# create nearest neighbor pairs
# Here we use AnnoyIndex as an example, but the process can be done by any
# external NN library that provides neighbors as a matrix of shape
# (n, n_neighbors_extra), where n_neighbors_extra is greater than or equal to
# n_neighbors in the following example.
n, dim = X.shape
n_neighbors = 10
n_neighbors_extra = 20  # candidate neighbors fetched per point
n_trees = 20  # size of the Annoy forest

if n_neighbors_extra < n_neighbors:
    raise ValueError(
        f"n_neighbors_extra ({n_neighbors_extra}) must be >= "
        f"n_neighbors ({n_neighbors})"
    )

nbrs = _annoy_neighbors(X, n_neighbors_extra, n_trees)

# No scaling is needed, so every neighbor gets the same (unit) distance.
# float32 / int32 dtypes are needed for numba acceleration in pacmap.
scaled_dist = np.ones((n, n_neighbors), dtype=np.float32)
X = X.astype(np.float32)

# make sure n_neighbors is the same number you want when fitting the data
pair_neighbors = pacmap.sample_neighbors_pair(X, scaled_dist, nbrs, np.int32(n_neighbors))

# initialize the LocalMAP instance and feed it the precomputed pair_neighbors
embedding = pacmap.LocalMAP(
    n_components=2,
    n_neighbors=n_neighbors,
    MN_ratio=0.5,
    FP_ratio=2.0,
    pair_neighbors=pair_neighbors,
)

# fit the data (row i of X_transformed corresponds to row i of X)
X_transformed = embedding.fit_transform(X, init="pca")

# visualize the embedding, colored by class label
fig, ax = plt.subplots(1, 1, figsize=(6, 6))
ax.scatter(X_transformed[:, 0], X_transformed[:, 1], cmap="Spectral", c=y, s=0.6)
plt.show()