digits: Clean up code a bit
This commit is contained in:
		
							parent
							
								
									347f9edc8d
								
							
						
					
					
						commit
						6f6d06ec91
					
				
					 1 changed files with 46 additions and 32 deletions
				
			
		
							
								
								
									
										78
									
								
								digits.py
									
										
									
									
									
								
							
							
						
						
									
										78
									
								
								digits.py
									
										
									
									
									
								
							|  | @ -1,11 +1,12 @@ | |||
| #!/usr/bin/env python | ||||
| # vim:tabstop=4 shiftwidth=4 tw=79: | ||||
| 
 | ||||
| ''' | ||||
| SVM, Random forest and KNearest digit recognition. | ||||
| Modified from the OpenCV example. | ||||
| 
 | ||||
| Sample loads a dataset of handwritten digits from '../data/digits.png'. | ||||
| Then it trains a Random Forest, SVM and KNearest classifiers on it and evaluates | ||||
| Sample loads a dataset of handwritten digits from '../data/digits.png'.  Then | ||||
| it trains Random Forest, SVM and KNearest classifiers on it and evaluates | ||||
| their accuracy. | ||||
| 
 | ||||
| The following preprocessing is applied to the dataset: | ||||
|  | @ -25,22 +26,20 @@ Usage: | |||
|      digits.py | ||||
| ''' | ||||
| 
 | ||||
| # built-in modules | ||||
| from multiprocessing.pool import ThreadPool | ||||
| 
 | ||||
| import cv2 | ||||
| 
 | ||||
| import numpy as np | ||||
| from numpy.linalg import norm | ||||
| 
 | ||||
| # local modules | ||||
| from common import clock, mosaic | ||||
| from common import mosaic | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| SZ = 20 # size of each digit is SZ x SZ | ||||
| SZ = 20  # size of each digit is SZ x SZ | ||||
| CLASS_N = 10 | ||||
| DIGITS_FN = 'digits.png' | ||||
| SIMPLE = True  # Use simple preprocessing or HOG features (for SVM) | ||||
| 
 | ||||
| 
 | ||||
| def split2d(img, cell_size, flatten=True): | ||||
|     h, w = img.shape[:2] | ||||
|  | @ -51,6 +50,7 @@ def split2d(img, cell_size, flatten=True): | |||
|         cells = cells.reshape(-1, sy, sx) | ||||
|     return cells | ||||
| 
 | ||||
| 
 | ||||
| def load_digits(fn): | ||||
|     print 'loading "%s" ...' % fn | ||||
|     digits_img = cv2.imread(fn, 0) | ||||
|  | @ -58,23 +58,28 @@ def load_digits(fn): | |||
|     labels = np.repeat(np.arange(CLASS_N), len(digits)/CLASS_N) | ||||
|     return digits, labels | ||||
| 
 | ||||
| 
 | ||||
| def deskew(img): | ||||
|     m = cv2.moments(img) | ||||
|     if abs(m['mu02']) < 1e-2: | ||||
|         return img.copy() | ||||
|     skew = m['mu11']/m['mu02'] | ||||
|     M = np.float32([[1, skew, -0.5*SZ*skew], [0, 1, 0]]) | ||||
|     img = cv2.warpAffine(img, M, (SZ, SZ), flags=cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR) | ||||
|     img = cv2.warpAffine(img, M, (SZ, SZ), | ||||
|                          flags=cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR) | ||||
|     return img | ||||
| 
 | ||||
| 
 | ||||
| class StatModel(object): | ||||
|     def load(self, fn): | ||||
|         self.model.load(fn) | ||||
| 
 | ||||
|     def save(self, fn): | ||||
|         self.model.save(fn) | ||||
| 
 | ||||
| 
 | ||||
| class KNearest(StatModel): | ||||
|     def __init__(self, k = 3): | ||||
|     def __init__(self, k=3): | ||||
|         self.k = k | ||||
|         self.model = cv2.KNearest() | ||||
| 
 | ||||
|  | @ -83,42 +88,47 @@ class KNearest(StatModel): | |||
|         self.model.train(samples, responses) | ||||
| 
 | ||||
|     def predict(self, samples): | ||||
|         retval, results, neigh_resp, dists = self.model.find_nearest(samples, self.k) | ||||
|         retval, results, neigh_resp, dists = self.model.find_nearest(samples, | ||||
|                                                                      self.k) | ||||
|         return results.ravel() | ||||
| 
 | ||||
| 
 | ||||
| class SVM(StatModel): | ||||
|     def __init__(self, kernel_type=cv2.SVM_RBF, C=1, gamma=0.5): | ||||
|         self.params = dict( kernel_type = kernel_type, | ||||
|                             svm_type = cv2.SVM_C_SVC, | ||||
|                             C = C, | ||||
|                             gamma = gamma ) | ||||
|         self.params = dict(kernel_type=kernel_type, | ||||
|                            svm_type=cv2.SVM_C_SVC, | ||||
|                            C=C, | ||||
|                            gamma=gamma) | ||||
|         self.model = cv2.SVM() | ||||
| 
 | ||||
|     def train(self, samples, responses): | ||||
|         self.model = cv2.SVM() | ||||
|         self.model.train(samples, responses, params = self.params) | ||||
|         self.model.train(samples, responses, params=self.params) | ||||
| 
 | ||||
|     def train_auto(self, samples, responses): | ||||
|         self.model = cv2.SVM() | ||||
|         self.model.train_auto(samples, responses, None, None, params = self.params) | ||||
|         self.model.train_auto(samples, responses, None, None, | ||||
|                               params=self.params) | ||||
| 
 | ||||
|     def predict(self, samples): | ||||
|         return self.model.predict_all(samples).ravel() | ||||
| 
 | ||||
| 
 | ||||
| class RForest(StatModel): | ||||
|     def __init__(self): | ||||
|         self.params = dict( max_depth=20 ) | ||||
|         self.params = dict(max_depth=20) | ||||
|         self.model = cv2.RTrees() | ||||
| 
 | ||||
|     def train(self, samples, responses): | ||||
|         self.model = cv2.RTrees() | ||||
|         self.model.train(samples, cv2.CV_ROW_SAMPLE, responses, | ||||
|               params=self.params) | ||||
|                          params=self.params) | ||||
| 
 | ||||
|     def predict(self, samples): | ||||
|         predictions = map(self.model.predict, samples) | ||||
|         return predictions | ||||
| 
 | ||||
| 
 | ||||
| def evaluate_model(model, digits, samples, labels): | ||||
|     resp = model.predict(samples) | ||||
|     err = (labels != resp).mean() | ||||
|  | @ -135,13 +145,15 @@ def evaluate_model(model, digits, samples, labels): | |||
|     for img, flag in zip(digits, resp == labels): | ||||
|         img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) | ||||
|         if not flag: | ||||
|             img[...,:2] = 0 | ||||
|             img[..., :2] = 0 | ||||
|         vis.append(img) | ||||
|     return mosaic(25, vis) | ||||
| 
 | ||||
| 
 | ||||
| def preprocess_simple(digits): | ||||
|     return np.float32(digits).reshape(-1, SZ*SZ) / 255.0 | ||||
| 
 | ||||
| 
 | ||||
| def preprocess_hog(digits): | ||||
|     samples = [] | ||||
|     for img in digits: | ||||
|  | @ -150,9 +162,10 @@ def preprocess_hog(digits): | |||
|         mag, ang = cv2.cartToPolar(gx, gy) | ||||
|         bin_n = 16 | ||||
|         bin = np.int32(bin_n*ang/(2*np.pi)) | ||||
|         bin_cells = bin[:10,:10], bin[10:,:10], bin[:10,10:], bin[10:,10:] | ||||
|         mag_cells = mag[:10,:10], mag[10:,:10], mag[:10,10:], mag[10:,10:] | ||||
|         hists = [np.bincount(b.ravel(), m.ravel(), bin_n) for b, m in zip(bin_cells, mag_cells)] | ||||
|         bin_cells = bin[:10, :10], bin[10:, :10], bin[:10, 10:], bin[10:, 10:] | ||||
|         mag_cells = mag[:10, :10], mag[10:, :10], mag[:10, 10:], mag[10:, 10:] | ||||
|         hists = [np.bincount(b.ravel(), m.ravel(), bin_n) | ||||
|                  for b, m in zip(bin_cells, mag_cells)] | ||||
|         hist = np.hstack(hists) | ||||
| 
 | ||||
|         # transform to Hellinger kernel | ||||
|  | @ -177,8 +190,10 @@ if __name__ == '__main__': | |||
|     digits, labels = digits[shuffle], labels[shuffle] | ||||
| 
 | ||||
|     digits2 = map(deskew, digits) | ||||
|     samples = preprocess_simple(digits2) | ||||
|     #samples = preprocess_hog(digits2) | ||||
|     if SIMPLE: | ||||
|         samples = preprocess_simple(digits2) | ||||
|     else: | ||||
|         samples = preprocess_hog(digits2) | ||||
| 
 | ||||
|     train_n = int(0.9*len(samples)) | ||||
|     cv2.imshow('test set', mosaic(25, digits[train_n:])) | ||||
|  | @ -201,17 +216,16 @@ if __name__ == '__main__': | |||
| 
 | ||||
|     print 'training SVM...' | ||||
| 
 | ||||
|     # HOG (original digits.py) | ||||
|     #model = SVM(kernel_type=cv2.SVM_RBF, C=2.67, gamma=5.383) | ||||
|     #model.train(samples_train, labels_train) | ||||
|     # Simple (cross-validation) | ||||
|     model = SVM(kernel_type=cv2.SVM_LINEAR, C=0.1) | ||||
|     model.train(samples_train, labels_train) | ||||
|     if SIMPLE: | ||||
|         model = SVM(kernel_type=cv2.SVM_LINEAR, C=0.1) | ||||
|         model.train(samples_train, labels_train) | ||||
|     else: | ||||
|         model = SVM(kernel_type=cv2.SVM_RBF, C=2.67, gamma=5.383) | ||||
|         model.train(samples_train, labels_train) | ||||
| 
 | ||||
|     vis = evaluate_model(model, digits_test, samples_test, labels_test) | ||||
|     cv2.imshow('SVM test', vis) | ||||
|     print 'saving SVM as "digits_svm.dat"...' | ||||
|     model.save('digits_svm.dat') | ||||
| 
 | ||||
| 
 | ||||
|     cv2.waitKey(0) | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue