1
0
Fork 0

Add prediction function for one-vs-all classification

master
neingeist 10 years ago
parent cf0d25440c
commit 3bf3d9fdc3

@ -1,17 +1,17 @@
function p = predictOneVsAll(all_theta, X) function p = predictOneVsAll(all_theta, X)
%PREDICT Predict the label for a trained one-vs-all classifier. The labels %PREDICT Predict the label for a trained one-vs-all classifier. The labels
%are in the range 1..K, where K = size(all_theta, 1). %are in the range 1..K, where K = size(all_theta, 1).
% p = PREDICTONEVSALL(all_theta, X) will return a vector of predictions % p = PREDICTONEVSALL(all_theta, X) will return a vector of predictions
% for each example in the matrix X. Note that X contains the examples in % for each example in the matrix X. Note that X contains the examples in
% rows. all_theta is a matrix where the i-th row is a trained logistic % rows. all_theta is a matrix where the i-th row is a trained logistic
% regression theta vector for the i-th class. You should set p to a vector % regression theta vector for the i-th class. You should set p to a vector
% of values from 1..K (e.g., p = [1; 3; 1; 2] predicts classes 1, 3, 1, 2 % of values from 1..K (e.g., p = [1; 3; 1; 2] predicts classes 1, 3, 1, 2
% for 4 examples) % for 4 examples)
m = size(X, 1); m = size(X, 1);
num_labels = size(all_theta, 1); num_labels = size(all_theta, 1);
% You need to return the following variables correctly % You need to return the following variables correctly
p = zeros(size(X, 1), 1); p = zeros(size(X, 1), 1);
% Add ones to the X data matrix % Add ones to the X data matrix
@ -24,17 +24,14 @@ X = [ones(m, 1) X];
% num_labels). % num_labels).
% %
% Hint: This code can be done all vectorized using the max function. % Hint: This code can be done all vectorized using the max function.
% In particular, the max function can also return the index of the % In particular, the max function can also return the index of the
% max element, for more information see 'help max'. If your examples % max element, for more information see 'help max'. If your examples
% are in rows, then, you can use max(A, [], 2) to obtain the max % are in rows, then, you can use max(A, [], 2) to obtain the max
% for each row. % for each row.
% %
[~, p] = max(X * all_theta', [], 2);
%disp(size(p));
% ========================================================================= % =========================================================================