Implement grid search and determine best parameters for C and sigma
This commit is contained in:
		
							parent
							
								
									e67166bc8e
								
							
						
					
					
						commit
						203cbc997c
					
				
					 1 changed files with 31 additions and 4 deletions
				
			
		|  | @ -2,8 +2,8 @@ function [C, sigma] = dataset3Params(X, y, Xval, yval) | ||||||
| %EX6PARAMS returns your choice of C and sigma for Part 3 of the exercise | %EX6PARAMS returns your choice of C and sigma for Part 3 of the exercise | ||||||
| %where you select the optimal (C, sigma) learning parameters to use for SVM | %where you select the optimal (C, sigma) learning parameters to use for SVM | ||||||
| %with RBF kernel | %with RBF kernel | ||||||
| %   [C, sigma] = EX6PARAMS(X, y, Xval, yval) returns your choice of C and  | %   [C, sigma] = EX6PARAMS(X, y, Xval, yval) returns your choice of C and | ||||||
| %   sigma. You should complete this function to return the optimal C and  | %   sigma. You should complete this function to return the optimal C and | ||||||
| %   sigma based on a cross-validation set. | %   sigma based on a cross-validation set. | ||||||
| % | % | ||||||
| 
 | 
 | ||||||
|  | @ -15,19 +15,46 @@ sigma = 0.3; | ||||||
| % Instructions: Fill in this function to return the optimal C and sigma | % Instructions: Fill in this function to return the optimal C and sigma | ||||||
| %               learning parameters found using the cross validation set. | %               learning parameters found using the cross validation set. | ||||||
| %               You can use svmPredict to predict the labels on the cross | %               You can use svmPredict to predict the labels on the cross | ||||||
| %               validation set. For example,  | %               validation set. For example, | ||||||
| %                   predictions = svmPredict(model, Xval); | %                   predictions = svmPredict(model, Xval); | ||||||
| %               will return the predictions on the cross validation set. | %               will return the predictions on the cross validation set. | ||||||
| % | % | ||||||
| %  Note: You can compute the prediction error using  | %  Note: You can compute the prediction error using | ||||||
| %        mean(double(predictions ~= yval)) | %        mean(double(predictions ~= yval)) | ||||||
| % | % | ||||||
| 
 | 
 | ||||||
|  | grid_search = 0; | ||||||
| 
 | 
 | ||||||
|  | if grid_search | ||||||
|  |   % Grid search | ||||||
|  |   load ex6data3.mat | ||||||
|  |   %grid = [0.01, 0.03]; | ||||||
|  |   grid = [0.01, 0.03, 0.1, 0.3, 1, 3, 10, 30]; | ||||||
|  |   results = []; | ||||||
|  |   for C = grid | ||||||
|  |     for sigma = grid | ||||||
| 
 | 
 | ||||||
|  |       fprintf('== C = %.2f, sigma = %.2f\n', C, sigma); | ||||||
|  |       model= svmTrain(X, y, C, @(x1, x2) gaussianKernel(x1, x2, sigma)); | ||||||
|  |       predictions = svmPredict(model, Xval); | ||||||
|  |       error = mean(double(predictions ~= yval)); | ||||||
|  |       fprintf('error = %.2f\n\n', error); | ||||||
| 
 | 
 | ||||||
|  |       results(end + 1,:) = [C, sigma, error]; | ||||||
| 
 | 
 | ||||||
|  |     end | ||||||
|  |   end | ||||||
| 
 | 
 | ||||||
|  |   [_, best_i] = min(results(:,3)); | ||||||
|  |   C = results(best_i, 1); | ||||||
|  |   sigma = results(best_i, 2); | ||||||
|  |   error = results(best_i, 3); | ||||||
|  |   fprintf('Best: C = %.2f, sigma = %.2f with error = %.2f\n', C, sigma, error); | ||||||
|  | else | ||||||
|  |   % Found through the grid search above | ||||||
|  |   C = 1.00; | ||||||
|  |   sigma = 0.10; | ||||||
|  | end | ||||||
| 
 | 
 | ||||||
| % ========================================================================= | % ========================================================================= | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Reference in a new issue