-
Notifications
You must be signed in to change notification settings - Fork 2
/
helper_generate_kmeans.py
94 lines (72 loc) · 2.95 KB
/
helper_generate_kmeans.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import argparse
import sys
import warnings
from typing import Tuple
import numpy as np
import phate
sys.path.append('../')
from utils.attribute_hashmap import AttributeHashmap
from utils.metrics import per_class_dice_coeff
from utils.segmentation import label_hint_seg
warnings.filterwarnings("ignore")
def generate_kmeans(shape: Tuple[int, int, int],
                    latent: np.array,
                    label_true: np.array,
                    num_workers: int = 1,
                    random_seed: int = 1) -> Tuple[float, np.array, np.array]:
    """Cluster per-pixel latent features with PHATE + k-means and score them.

    Args:
        shape: ``(H, W, C)`` -- image height, image width and number of latent
            channels.
        latent: Per-pixel latent features; must have shape ``(H * W, C)``.
        label_true: Ground-truth label map. Pixels with a value > 0 form the
            binary foreground mask ``seg_true`` used for scoring.
        num_workers: Number of parallel jobs passed to PHATE.
        random_seed: Seed for PHATE and k-means. If clustering raises, one
            retry is made with ``random_seed + 1``.

    Returns:
        ``(dice, label_pred, seg_pred)``: ``dice`` is whatever
        ``per_class_dice_coeff(seg_pred, seg_true)`` returns, ``label_pred``
        is the ``(H, W)`` map of cluster ids, and ``seg_pred`` is the
        segmentation produced by ``label_hint_seg`` from ``label_pred`` and
        ``label_true``.

    Raises:
        AssertionError: if ``latent`` does not have shape ``(H * W, C)``.
    """
    H, W, C = shape
    assert latent.shape == (H * W, C), \
        'latent must have shape (H * W, C) = %s, got %s' % (
            (H * W, C), latent.shape)

    seg_true = label_true > 0

    # Very occasionally, SVD won't converge; retry once with a different seed.
    # `Exception` (not a bare `except:`) so Ctrl-C / SystemExit still abort
    # the potentially long-running PHATE fit instead of triggering a retry.
    try:
        clusters = phate_clustering(latent=latent,
                                    random_seed=random_seed,
                                    num_workers=num_workers)
    except Exception:
        clusters = phate_clustering(latent=latent,
                                    random_seed=random_seed + 1,
                                    num_workers=num_workers)

    # One cluster id per pixel: [H x W] -> [H, W].
    label_pred = clusters.reshape((H, W))
    seg_pred = label_hint_seg(label_pred=label_pred, label_true=label_true)
    return per_class_dice_coeff(seg_pred, seg_true), label_pred, seg_pred
def phate_clustering(latent: np.array,
                     random_seed: int,
                     num_workers: int,
                     *,
                     n_clusters: int = 10,
                     n_components: int = 3,
                     knn: int = 100,
                     n_landmark: int = 500,
                     t: int = 2) -> np.array:
    """Fit a PHATE operator on ``latent`` and k-means cluster it.

    Args:
        latent: Feature matrix, one row per sample (``(H * W, C)`` when called
            from ``generate_kmeans``).
        random_seed: Seed passed to both ``phate.PHATE`` and
            ``phate.cluster.kmeans``.
        num_workers: Number of parallel jobs (``n_jobs``) for PHATE.
        n_clusters: Number of k-means clusters.
        n_components: PHATE embedding dimensionality.
        knn: Number of nearest neighbours for the PHATE graph.
        n_landmark: Number of landmarks PHATE uses.
        t: PHATE diffusion time.

    The keyword-only hyperparameters default to the values this helper has
    always used, so existing callers are unaffected.

    Returns:
        The cluster assignments returned by ``phate.cluster.kmeans``
        (presumably one integer id per row of ``latent`` -- the caller
        reshapes them to ``(H, W)``).
    """
    phate_operator = phate.PHATE(n_components=n_components,
                                 knn=knn,
                                 n_landmark=n_landmark,
                                 t=t,
                                 verbose=False,
                                 random_state=random_seed,
                                 n_jobs=num_workers)
    # The returned embedding is not used here; the call fits the operator so
    # that phate.cluster.kmeans can consume it below.
    phate_operator.fit_transform(latent)
    clusters = phate.cluster.kmeans(phate_operator,
                                    n_clusters=n_clusters,
                                    random_state=random_seed)
    return clusters
if __name__ == '__main__':
    # Load an .npz archive with 'image', 'label' and 'latent' arrays, cluster
    # the latents, and save the inputs plus the clustering results.
    parser = argparse.ArgumentParser()
    parser.add_argument('--load_path', type=str, required=True)
    parser.add_argument('--save_path', type=str, required=True)
    parser.add_argument('--num_workers', type=int, required=True)
    args = AttributeHashmap(vars(parser.parse_args()))

    # np.load on an .npz archive returns an NpzFile that keeps the file open;
    # the context manager closes it once the needed arrays have been read.
    with np.load(args.load_path) as numpy_array:
        image = numpy_array['image']
        label_true = numpy_array['label']
        latent = numpy_array['latent']

    # Map the image from [-1, 1] to [0, 1] (assumes the stored image is in
    # [-1, 1] -- NOTE(review): confirm against the producer of the .npz).
    image = (image + 1) / 2

    H, W = label_true.shape[:2]
    C = latent.shape[-1]

    dice_score, label_pred, seg_pred = generate_kmeans(
        (H, W, C), latent, label_true, num_workers=args.num_workers)

    with open(args.save_path, 'wb') as f:
        np.savez(f,
                 image=image,
                 label=label_true,
                 latent=latent,
                 label_kmeans=label_pred,
                 seg_kmeans=seg_pred)

    sys.stdout.write('SUCCESS! %s, dice: %s' %
                     (args.load_path.split('/')[-1], dice_score))