-
Notifications
You must be signed in to change notification settings - Fork 0
/
Neo_TSVM.m
66 lines (50 loc) · 1.49 KB
/
Neo_TSVM.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
function [Pre_Label, Testacc, Trainacc, time, z1, z2] = Neo_TSVM(Xtrain, Ytrain, Xtest, Ytest, c1, c2, c3, c4)
%NEO_TSVM Fit two hyperplanes in closed form and classify by nearest plane.
%
%   [Pre_Label, Testacc, Trainacc, time, z1, z2] = Neo_TSVM(Xtrain, Ytrain, ...
%       Xtest, Ytest, c1, c2, c3, c4)
%
%   Inputs
%     Xtrain - n x f training data, one sample per row.
%     Ytrain - n training labels (row or column). Samples labelled +1 form
%              class A; every other label is treated as class B. Labels are
%              expected to be +1 / -1, since predictions are +1 / -1.
%     Xtest  - t x f test data with the same f columns as Xtrain.
%     Ytest  - t test labels (row or column), expected to be +1 / -1.
%     c1..c4 - scalar weights used in the two linear solves below.
%
%   With Ab = [A, 1], Bb = [B, 1] (class rows with a ones column appended)
%   and mA, mB their column means, the hyperplanes are
%       2*(c1*Bb'*Bb + c3*I) * z2 = -c4*mA
%       2*(c1*Ab'*Ab + c3*I) * z1 = -(c4*mB + c2*z2)
%   i.e. the stationary points of
%       c1*||Bb*z2||^2 + c3*||z2||^2 + c4*mA'*z2
%       c1*||Ab*z1||^2 + c3*||z1||^2 + (c4*mB + c2*z2)'*z1.
%   Note that c3 also penalises the bias (last) entry of each z.
%
%   Outputs
%     Pre_Label - t x 1 predictions for Xtest: +1 if the point is nearer
%                 hyperplane z1, -1 if nearer z2, 0 if exactly equidistant.
%     Testacc   - percentage of Pre_Label equal to Ytest.
%     Trainacc  - percentage of training predictions equal to Ytrain.
%     time      - seconds spent on the class means and the two solves
%                 (prediction is not timed).
%     z1, z2    - (f+1) x 1 augmented vectors [w; b]; hyperplane k is
%                 x*w + b = 0 with w = zk(1:end-1), b = zk(end).
%
%   If w is all zeros for either plane (e.g. c4 = 0 gives z1 = z2 = 0), the
%   distances divide by zero, predictions are NaN and accuracies are 0.

% Split the training rows by class and append a bias column of ones.
A = Xtrain(Ytrain == 1, :);
B = Xtrain(Ytrain ~= 1, :);
Ab = [A, ones(size(A, 1), 1)];
Bb = [B, ones(size(B, 1), 1)];
D  = [Xtrain, ones(size(Xtrain, 1), 1)];
Xt = [Xtest,  ones(size(Xtest, 1), 1)];
I  = eye(size(D, 2));

tic;
meanA = mean(Ab, 1)';   % (f+1) x 1 mean of class A (augmented)
meanB = mean(Bb, 1)';   % (f+1) x 1 mean of class B (augmented)
% Solve the linear systems directly rather than forming inv(); z2 must be
% computed first because it feeds the right-hand side of the z1 system.
z2 = -((2 * (c1*(Bb'*Bb) + c3*I)) \ (c4*meanA));
z1 = -((2 * (c1*(Ab'*Ab) + c3*I)) \ (c4*meanB + c2*z2));
time = toc;

% Column-ise the labels so a row-vector Ytrain/Ytest is compared
% element-wise instead of being implicitly expanded to a matrix.
train_Label = predict_labels(D, z1, z2);
Trainacc = mean(train_Label == Ytrain(:)) * 100;

Pre_Label = predict_labels(Xt, z1, z2);
Testacc = mean(Pre_Label == Ytest(:)) * 100;
end

function labels = predict_labels(X, z1, z2)
%PREDICT_LABELS Label rows of augmented X by the nearer of two hyperplanes.
%   Returns +1 where the distance to z1 is smaller, -1 where the distance to
%   z2 is smaller, and 0 on exact ties (sign of the distance difference).
d1 = abs(X*z1) / norm(z1(1:end-1));   % unsigned distance to hyperplane 1
d2 = abs(X*z2) / norm(z2(1:end-1));   % unsigned distance to hyperplane 2
labels = sign(d2 - d1);
end