You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
193 lines
5.8 KiB
Matlab
193 lines
5.8 KiB
Matlab
function [model] = svmTrain(X, Y, C, kernelFunction, ...
                            tol, max_passes)
%SVMTRAIN Trains an SVM classifier using a simplified version of the SMO
%algorithm.
%   [model] = SVMTRAIN(X, Y, C, kernelFunction, tol, max_passes) trains an
%   SVM classifier and returns the trained model.
%
%   Inputs:
%     X              - m-by-n matrix of training examples. Each row is a
%                      training example, and the jth column holds the jth
%                      feature.
%     Y              - m-by-1 column of labels containing 1 for positive
%                      examples and 0 for negative examples (0 is mapped
%                      to -1 internally).
%     C              - the standard SVM regularization parameter; every
%                      alpha is kept inside [0, C].
%     kernelFunction - function handle called as kernelFunction(x1, x2)
%                      with two column vectors, returning a scalar.
%     tol            - (optional, default 1e-3) tolerance used both for the
%                      KKT violation check and for deciding whether a
%                      change in alpha is significant.
%     max_passes     - (optional, default 5) number of consecutive sweeps
%                      over the dataset without any change to alpha before
%                      the algorithm quits.
%
%   Output:
%     model - struct with the fields
%       X              rows of X whose alpha is > 0 (support vectors)
%       y              their labels, in {-1, +1}
%       alphas         their Lagrange multipliers
%       b              the intercept
%       kernelFunction the handle that was passed in
%       w              ((alphas.*Y)'*X)'; this is the weight vector of the
%                      decision function only when the kernel is linear
%
%   Raises an error when X holds exactly one example, because SMO needs a
%   pair of distinct examples for every update.
%
% Note: This is a simplified version of the SMO algorithm for training
%       SVMs. In practice, if you want to train an SVM classifier, we
%       recommend using an optimized package such as:
%
%           LIBSVM   (http://www.csie.ntu.edu.tw/~cjlin/libsvm/)
%           SVMLight (http://svmlight.joachims.org/)
%

if ~exist('tol', 'var') || isempty(tol)
    tol = 1e-3;
end

if ~exist('max_passes', 'var') || isempty(max_passes)
    max_passes = 5;
end

% Data parameters
m = size(X, 1);

% With a single example the "pick j ~= i" loop below could never
% terminate, so fail loudly instead of hanging.
if m == 1
    error('svmTrain:tooFewExamples', ...
          'svmTrain needs at least two training examples to form an SMO pair.');
end

% Map 0 to -1
Y(Y==0) = -1;

% Variables
alphas = zeros(m, 1);
b = 0;
E = zeros(m, 1);
passes = 0;

% Pre-compute the Kernel Matrix since our dataset is small
% (in practice, optimized SVM packages that handle large datasets
%  gracefully will _not_ do this)
%
% The linear and Gaussian kernels are recognised by the name of the
% function handle and computed in vectorized form so that training runs
% faster.
kernelName = func2str(kernelFunction);
if strcmp(kernelName, 'linearKernel')
    % Vectorized computation for the Linear Kernel
    % This is equivalent to computing the kernel on every pair of examples
    K = X*X';
elseif ~isempty(strfind(kernelName, 'gaussianKernel'))
    % Vectorized RBF Kernel
    % K first holds the squared distance ||x_i - x_j||^2 for every pair.
    X2 = sum(X.^2, 2);
    K = bsxfun(@plus, X2, bsxfun(@plus, X2', - 2 * (X * X')));
    % kernelFunction(1, 0) is the kernel value at squared distance 1;
    % raising it to the squared distance yields the kernel for each pair.
    % NOTE(review): assumes kernelFunction(x1, x2) has the form
    % exp(-||x1 - x2||^2 / (2*sigma^2)) -- confirm against gaussianKernel.
    K = kernelFunction(1, 0) .^ K;
else
    % Pre-compute the Kernel Matrix
    % The following can be slow due to the lack of vectorization
    K = zeros(m);
    for i = 1:m
        for j = i:m
            K(i,j) = kernelFunction(X(i,:)', X(j,:)');
            K(j,i) = K(i,j); % the matrix is symmetric
        end
    end
end

% Train
fprintf('\nTraining ...');
dots = 12;
while passes < max_passes

    num_changed_alphas = 0;
    for i = 1:m

        % Calculate Ei = f(x(i)) - y(i) using (2).
        E(i) = b + sum(alphas.*Y.*K(:,i)) - Y(i);

        % Only optimize examples that violate the KKT conditions by more
        % than tol.
        if ((Y(i)*E(i) < -tol && alphas(i) < C) || (Y(i)*E(i) > tol && alphas(i) > 0))

            % In practice, there are many heuristics one can use to select
            % the i and j. In this simplified code, we select them randomly.
            j = ceil(m * rand());
            while j == i  % Make sure i \neq j
                j = ceil(m * rand());
            end

            % Calculate Ej = f(x(j)) - y(j) using (2).
            E(j) = b + sum(alphas.*Y.*K(:,j)) - Y(j);

            % Save old alphas
            alpha_i_old = alphas(i);
            alpha_j_old = alphas(j);

            % Compute L and H by (10) or (11).
            if (Y(i) == Y(j))
                L = max(0, alphas(j) + alphas(i) - C);
                H = min(C, alphas(j) + alphas(i));
            else
                L = max(0, alphas(j) - alphas(i));
                H = min(C, C + alphas(j) - alphas(i));
            end

            if (L == H)
                % continue to next i.
                continue;
            end

            % Compute eta by (14).
            eta = 2 * K(i,j) - K(i,i) - K(j,j);
            if (eta >= 0)
                % continue to next i.
                continue;
            end

            % Compute and clip new value for alpha j using (12) and (15).
            alphas(j) = alphas(j) - (Y(j) * (E(i) - E(j))) / eta;

            % Clip
            alphas(j) = min(H, alphas(j));
            alphas(j) = max(L, alphas(j));

            % Check if change in alpha is significant
            if (abs(alphas(j) - alpha_j_old) < tol)
                % Not significant: restore the old value and
                % continue to next i.
                alphas(j) = alpha_j_old;
                continue;
            end

            % Determine value for alpha i using (16).
            alphas(i) = alphas(i) + Y(i)*Y(j)*(alpha_j_old - alphas(j));

            % Compute b1 and b2 using (17) and (18) respectively.
            % In b1 the alpha_i change is weighted by K(i,i) and the
            % alpha_j change by K(i,j); in b2 by K(i,j) and K(j,j).
            delta_i = Y(i) * (alphas(i) - alpha_i_old);
            delta_j = Y(j) * (alphas(j) - alpha_j_old);
            b1 = b - E(i) ...
                 - delta_i * K(i,i) ...
                 - delta_j * K(i,j);
            b2 = b - E(j) ...
                 - delta_i * K(i,j) ...
                 - delta_j * K(j,j);

            % Compute b by (19).
            if (0 < alphas(i) && alphas(i) < C)
                b = b1;
            elseif (0 < alphas(j) && alphas(j) < C)
                b = b2;
            else
                b = (b1+b2)/2;
            end

            num_changed_alphas = num_changed_alphas + 1;

        end

    end

    % Count consecutive sweeps without any update; any update resets it.
    if (num_changed_alphas == 0)
        passes = passes + 1;
    else
        passes = 0;
    end

    % Progress indicator: one dot per sweep, wrapped after 78 dots.
    fprintf('.');
    dots = dots + 1;
    if dots > 78
        dots = 0;
        fprintf('\n');
    end
    if exist('OCTAVE_VERSION')
        fflush(stdout);
    end
end
fprintf(' Done! \n\n');

% Save the model (only examples with alpha > 0 are kept)
idx = alphas > 0;
model.X = X(idx,:);
model.y = Y(idx);
model.kernelFunction = kernelFunction;
model.b = b;
model.alphas = alphas(idx);
model.w = ((alphas.*Y)'*X)';

end
|