Implement grid search and determine best parameters for C and sigma
This commit is contained in:
		
							parent
							
								
									e67166bc8e
								
							
						
					
					
						commit
						203cbc997c
					
				
					 1 changed files with 31 additions and 4 deletions
				
			
		| 
						 | 
				
			
			@ -23,11 +23,38 @@ sigma = 0.3;
 | 
			
		|||
%        mean(double(predictions ~= yval))
 | 
			
		||||
%
 | 
			
		||||
 | 
			
		||||
grid_search = 0;
 | 
			
		||||
 | 
			
		||||
if grid_search
 | 
			
		||||
  % Grid search
 | 
			
		||||
  load ex6data3.mat
 | 
			
		||||
  %grid = [0.01, 0.03];
 | 
			
		||||
  grid = [0.01, 0.03, 0.1, 0.3, 1, 3, 10, 30];
 | 
			
		||||
  results = [];
 | 
			
		||||
  for C = grid
 | 
			
		||||
    for sigma = grid
 | 
			
		||||
 | 
			
		||||
      fprintf('== C = %.2f, sigma = %.2f\n', C, sigma);
 | 
			
		||||
      model= svmTrain(X, y, C, @(x1, x2) gaussianKernel(x1, x2, sigma));
 | 
			
		||||
      predictions = svmPredict(model, Xval);
 | 
			
		||||
      error = mean(double(predictions ~= yval));
 | 
			
		||||
      fprintf('error = %.2f\n\n', error);
 | 
			
		||||
 | 
			
		||||
      results(end + 1,:) = [C, sigma, error];
 | 
			
		||||
 | 
			
		||||
    end
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  [_, best_i] = min(results(:,3));
 | 
			
		||||
  C = results(best_i, 1);
 | 
			
		||||
  sigma = results(best_i, 2);
 | 
			
		||||
  error = results(best_i, 3);
 | 
			
		||||
  fprintf('Best: C = %.2f, sigma = %.2f with error = %.2f\n', C, sigma, error);
 | 
			
		||||
else
 | 
			
		||||
  % Found through the grid search above
 | 
			
		||||
  C = 1.00;
 | 
			
		||||
  sigma = 0.10;
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
% =========================================================================
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Reference in a new issue