-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathCal_NMI.m
49 lines (41 loc) · 1.44 KB
/
Cal_NMI.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
function score = Cal_NMI(true_labels, cluster_labels)
%CAL_NMI Normalized mutual information (NMI) between two labelings.
%   score = Cal_NMI(true_labels, cluster_labels) compares a reference
%   labeling with a clustering result over the same set of samples.
%
%   Inputs
%     true_labels    - numeric/logical array with one label per sample.
%     cluster_labels - numeric/logical array with one label per sample;
%                      must have the same number of elements as
%                      true_labels.
%   Both inputs are flattened, so row vectors, column vectors and
%   mixtures of the two are accepted. Label values are only used as
%   identifiers: they may be negative, non-integer or non-consecutive.
%
%   Output
%     score - I(T;C) / sqrt(H(T) * H(C)). The score is 0 when either
%             labeling consists of a single group (zero entropy) or when
%             the inputs are empty.
%
%   Errors
%     Cal_NMI:sizeMismatch - inputs have different numbers of elements.
%     Cal_NMI:nanLabel     - a label is NaN.

% Flatten to column vectors so orientation and shape do not matter.
true_labels = double(true_labels(:));
cluster_labels = double(cluster_labels(:));

if numel(true_labels) ~= numel(cluster_labels)
    error('Cal_NMI:sizeMismatch', ...
        'true_labels and cluster_labels must have the same number of elements.');
end
if any(isnan(true_labels)) || any(isnan(cluster_labels))
    error('Cal_NMI:nanLabel', 'Labels must not contain NaN.');
end

n = numel(true_labels);
if n == 0
    % Nothing to compare; use the same convention as the zero-entropy case.
    score = 0;
    return;
end

% Map labels to consecutive indices 1..K. This keeps the contingency table
% K_cls-by-K_true no matter how large or sparse the raw label values are,
% and it also makes non-integer label values usable as identifiers.
[~, ~, true_idx] = unique(true_labels);
[~, ~, cluster_idx] = unique(cluster_labels);
true_idx = true_idx(:);
cluster_idx = cluster_idx(:);
num_true = max(true_idx);
num_cls = max(cluster_idx);

% Contingency table: rows are cluster labels, columns are true labels.
cmat = accumarray([cluster_idx, true_idx], 1, [num_cls, num_true]);

n_i = sum(cmat, 1); % Samples per true label (1-by-num_true), all > 0
n_j = sum(cmat, 2); % Samples per cluster label (num_cls-by-1), all > 0

% Mutual-information term: sum over non-empty cells of
% n_ij * log(n * n_ij / (n_i * n_j)). Empty cells contribute zero.
marginal_product = n_j * n_i; % outer product, same size as cmat
nz = cmat > 0;
score = sum(cmat(nz) .* log((n * cmat(nz)) ./ marginal_product(nz)));

% Normaliser: sqrt( sum(n_i*log(n_i/n)) * sum(n_j*log(n_j/n)) ).
% Both sums are <= 0, so their product is non-negative.
h_true = sum(n_i .* log(n_i / n));
h_cls = sum(n_j .* log(n_j / n));
denominator = sqrt(h_true * h_cls);

% A zero denominator means at least one labeling has a single group.
if denominator == 0
    score = 0;
else
    score = score / denominator;
end