-
Notifications
You must be signed in to change notification settings - Fork 0
/
trainSVM.m
125 lines (95 loc) · 2.77 KB
/
trainSVM.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
function [model] = trainSVM(X, Y, C, kernelFunction, ...
                            tol, maxpass)
%TRAINSVM Train an SVM classifier with a simplified SMO loop.
%   model = TRAINSVM(X, Y, C, kernelFunction, tol, maxpass)
%
%   Inputs
%     X              - m-by-n matrix, one training example per row.
%     Y              - label vector with one entry per row of X; entries
%                      equal to 0 are remapped to -1 before training.
%                      NOTE(review): the element-wise products below assume
%                      Y is an m-by-1 column -- confirm against callers.
%     C              - upper bound applied to every multiplier alpha(i).
%     kernelFunction - stored verbatim in model.kernelFunction. It is NOT
%                      evaluated here: the kernel matrix is always the
%                      linear Gram matrix X*X'.
%     tol            - (optional, default 1e-3) tolerance used both for the
%                      KKT violation check and for the minimum accepted
%                      change of a multiplier.
%     maxpass        - (optional, default 5) number of consecutive sweeps
%                      over the data without any multiplier change after
%                      which training stops.
%
%   Output
%     model - struct with fields
%               X, y            rows of X / entries of Y whose alpha > 0
%               kernelFunction  the kernelFunction argument, unchanged
%               b               threshold
%               alphas          the multipliers that are > 0
%               w               ((alphas.*Y)'*X)', computed from all rows
%
%   The second index j of each optimised pair is drawn with rand(), so
%   results can differ between runs unless the generator is seeded.
%   Progress is reported on standard output.

    if ~exist('tol', 'var') || isempty(tol)
        tol = 1e-3;
    end
    if ~exist('maxpass', 'var') || isempty(maxpass)
        maxpass = 5;
    end

    m = size(X, 1);

    % Work with labels in {-1, +1}.
    Y(Y == 0) = -1;

    alphas = zeros(m, 1);
    b = 0;
    E = zeros(m, 1);
    pass = 0;

    % Linear Gram matrix: K(i,j) = X(i,:) * X(j,:)'.
    K = X * X';

    % SMO optimises two multipliers at a time. With fewer than two examples
    % there is no second index to pick, and the "draw j ~= i" loop below
    % would never terminate, so training is skipped in that case.
    canTrain = (m >= 2);
    if ~canTrain
        warning('trainSVM:tooFewExamples', ...
                'trainSVM needs at least two training examples; returning an untrained model.');
    end

    fprintf('\nTraining ...');
    dots = 12;
    while canTrain && pass < maxpass
        num_changed_alphas = 0;
        for i = 1:m
            % Prediction error on example i with the current multipliers.
            E(i) = b + sum(alphas .* Y .* K(:, i)) - Y(i);

            % Only optimise alpha(i) if it violates the KKT conditions by
            % more than tol.
            violatesKKT = (Y(i) * E(i) < -tol && alphas(i) < C) || ...
                          (Y(i) * E(i) >  tol && alphas(i) > 0);
            if ~violatesKKT
                continue;
            end

            % Pick a random partner index j different from i.
            j = ceil(m * rand());
            while j == i
                j = ceil(m * rand());
            end

            E(j) = b + sum(alphas .* Y .* K(:, j)) - Y(j);

            alpha_i_old = alphas(i);
            alpha_j_old = alphas(j);

            % Bounds that keep both multipliers inside [0, C] while the
            % pair moves along its linear constraint.
            if Y(i) == Y(j)
                L = max(0, alphas(j) + alphas(i) - C);
                H = min(C, alphas(j) + alphas(i));
            else
                L = max(0, alphas(j) - alphas(i));
                H = min(C, C + alphas(j) - alphas(i));
            end
            if L == H
                continue;
            end

            % Second derivative of the objective along the constraint
            % line; a step is only taken when it is strictly negative.
            eta = 2 * K(i, j) - K(i, i) - K(j, j);
            if eta >= 0
                continue;
            end

            % Unconstrained optimum for alpha(j), then clip to [L, H].
            alphas(j) = alphas(j) - (Y(j) * (E(i) - E(j))) / eta;
            alphas(j) = min(H, alphas(j));
            alphas(j) = max(L, alphas(j));

            % Ignore negligible moves and restore the old value.
            if abs(alphas(j) - alpha_j_old) < tol
                alphas(j) = alpha_j_old;
                continue;
            end

            % Move alpha(i) by the same amount in the opposite direction.
            alphas(i) = alphas(i) + Y(i) * Y(j) * (alpha_j_old - alphas(j));

            % Threshold candidates. b1 makes example i satisfy the KKT
            % conditions, so the alpha(i) change is weighted by K(i,i);
            % b2 does the same for example j (K(i,j) and K(j,j)).
            delta_i = Y(i) * (alphas(i) - alpha_i_old);
            delta_j = Y(j) * (alphas(j) - alpha_j_old);
            b1 = b - E(i) - delta_i * K(i, i) - delta_j * K(i, j);
            b2 = b - E(j) - delta_i * K(i, j) - delta_j * K(j, j);

            if 0 < alphas(i) && alphas(i) < C
                b = b1;
            elseif 0 < alphas(j) && alphas(j) < C
                b = b2;
            else
                b = (b1 + b2) / 2;
            end

            num_changed_alphas = num_changed_alphas + 1;
        end

        % Count consecutive sweeps that changed nothing.
        if num_changed_alphas == 0
            pass = pass + 1;
        else
            pass = 0;
        end

        % Progress indicator, wrapped after 78 dots.
        fprintf('.');
        dots = dots + 1;
        if dots > 78
            dots = 0;
            fprintf('\n');
        end
        if exist('OCTAVE_VERSION')
            fflush(stdout);
        end
    end
    fprintf(' Done! \n\n');

    % Keep only the examples with a strictly positive multiplier.
    idx = alphas > 0;
    model.X = X(idx, :);
    model.y = Y(idx);
    model.kernelFunction = kernelFunction;
    model.b = b;
    model.alphas = alphas(idx);
    model.w = ((alphas .* Y)' * X)';
end