diff --git a/digits.py b/digits.py
index c553194..a94154b 100755
--- a/digits.py
+++ b/digits.py
@@ -1,11 +1,12 @@
 #!/usr/bin/env python
+# vim:tabstop=4 shiftwidth=4 tw=79:
 '''
 SVM, Random forest and KNearest digit recognition.
 
 Modified from the OpenCV example.
 
-Sample loads a dataset of handwritten digits from '../data/digits.png'.
-Then it trains a Random Forest, SVM and KNearest classifiers on it and evaluates
+Sample loads a dataset of handwritten digits from '../data/digits.png'. Then
+it trains Random Forest, SVM and KNearest classifiers on it and evaluates
 their accuracy.
 
 Following preprocessing is applied to the dataset:
@@ -25,22 +26,20 @@
 Usage:
   digits.py
 '''
 
-# built-in modules
-from multiprocessing.pool import ThreadPool
-
 import cv2
 import numpy as np
 from numpy.linalg import norm
 
 # local modules
-from common import clock, mosaic
-
+from common import mosaic
 
-SZ = 20 # size of each digit is SZ x SZ
+SZ = 20  # size of each digit is SZ x SZ
 CLASS_N = 10
 DIGITS_FN = 'digits.png'
 
+SIMPLE = True  # Use simple preprocessing or HOG features (for SVM)
+
 
 def split2d(img, cell_size, flatten=True):
     h, w = img.shape[:2]
@@ -51,6 +50,7 @@ def split2d(img, cell_size, flatten=True):
         cells = cells.reshape(-1, sy, sx)
     return cells
 
+
 def load_digits(fn):
     print 'loading "%s" ...' % fn
     digits_img = cv2.imread(fn, 0)
@@ -58,23 +58,28 @@
     labels = np.repeat(np.arange(CLASS_N), len(digits)/CLASS_N)
     return digits, labels
 
+
 def deskew(img):
     m = cv2.moments(img)
     if abs(m['mu02']) < 1e-2:
         return img.copy()
     skew = m['mu11']/m['mu02']
     M = np.float32([[1, skew, -0.5*SZ*skew], [0, 1, 0]])
-    img = cv2.warpAffine(img, M, (SZ, SZ), flags=cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR)
+    img = cv2.warpAffine(img, M, (SZ, SZ),
+                         flags=cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR)
     return img
 
+
 class StatModel(object):
     def load(self, fn):
         self.model.load(fn)
+
     def save(self, fn):
         self.model.save(fn)
 
+
 class KNearest(StatModel):
-    def __init__(self, k = 3):
+    def __init__(self, k=3):
         self.k = k
         self.model = cv2.KNearest()
 
@@ -83,42 +88,47 @@ def train(self, samples, responses):
         self.model.train(samples, responses)
 
     def predict(self, samples):
-        retval, results, neigh_resp, dists = self.model.find_nearest(samples, self.k)
+        retval, results, neigh_resp, dists = self.model.find_nearest(samples,
+                                                                     self.k)
         return results.ravel()
 
+
 class SVM(StatModel):
     def __init__(self, kernel_type=cv2.SVM_RBF, C=1, gamma=0.5):
-        self.params = dict( kernel_type = kernel_type,
-                            svm_type = cv2.SVM_C_SVC,
-                            C = C,
-                            gamma = gamma )
+        self.params = dict(kernel_type=kernel_type,
+                           svm_type=cv2.SVM_C_SVC,
+                           C=C,
+                           gamma=gamma)
         self.model = cv2.SVM()
 
     def train(self, samples, responses):
         self.model = cv2.SVM()
-        self.model.train(samples, responses, params = self.params)
+        self.model.train(samples, responses, params=self.params)
 
     def train_auto(self, samples, responses):
         self.model = cv2.SVM()
-        self.model.train_auto(samples, responses, None, None, params = self.params)
+        self.model.train_auto(samples, responses, None, None,
+                              params=self.params)
 
     def predict(self, samples):
        return self.model.predict_all(samples).ravel()
 
+
 class RForest(StatModel):
     def __init__(self):
-        self.params = dict( max_depth=20 )
+        self.params = dict(max_depth=20)
         self.model = cv2.RTrees()
 
     def train(self, samples, responses):
         self.model = cv2.RTrees()
         self.model.train(samples, cv2.CV_ROW_SAMPLE, responses,
-            params=self.params)
+                         params=self.params)
 
     def predict(self, samples):
         predictions = map(self.model.predict, samples)
         return predictions
 
+
 def evaluate_model(model, digits, samples, labels):
     resp = model.predict(samples)
     err = (labels != resp).mean()
@@ -135,13 +145,15 @@ def evaluate_model(model, digits, samples, labels):
     for img, flag in zip(digits, resp == labels):
         img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
         if not flag:
-            img[...,:2] = 0
+            img[..., :2] = 0
         vis.append(img)
     return mosaic(25, vis)
 
+
 def preprocess_simple(digits):
     return np.float32(digits).reshape(-1, SZ*SZ) / 255.0
 
+
 def preprocess_hog(digits):
     samples = []
     for img in digits:
@@ -150,9 +162,10 @@
         mag, ang = cv2.cartToPolar(gx, gy)
         bin_n = 16
         bin = np.int32(bin_n*ang/(2*np.pi))
-        bin_cells = bin[:10,:10], bin[10:,:10], bin[:10,10:], bin[10:,10:]
-        mag_cells = mag[:10,:10], mag[10:,:10], mag[:10,10:], mag[10:,10:]
-        hists = [np.bincount(b.ravel(), m.ravel(), bin_n) for b, m in zip(bin_cells, mag_cells)]
+        bin_cells = bin[:10, :10], bin[10:, :10], bin[:10, 10:], bin[10:, 10:]
+        mag_cells = mag[:10, :10], mag[10:, :10], mag[:10, 10:], mag[10:, 10:]
+        hists = [np.bincount(b.ravel(), m.ravel(), bin_n)
+                 for b, m in zip(bin_cells, mag_cells)]
         hist = np.hstack(hists)
 
         # transform to Hellinger kernel
@@ -177,8 +190,10 @@ if __name__ == '__main__':
     digits, labels = digits[shuffle], labels[shuffle]
     digits2 = map(deskew, digits)
 
-    samples = preprocess_simple(digits2)
-    #samples = preprocess_hog(digits2)
+    if SIMPLE:
+        samples = preprocess_simple(digits2)
+    else:
+        samples = preprocess_hog(digits2)
 
     train_n = int(0.9*len(samples))
     cv2.imshow('test set', mosaic(25, digits[train_n:]))
@@ -201,17 +216,16 @@
 
     print 'training SVM...'
 
-    # HOG (original digits.py)
-    #model = SVM(kernel_type=cv2.SVM_RBF, C=2.67, gamma=5.383)
-    #model.train(samples_train, labels_train)
-    # Simple (cross-validation)
-    model = SVM(kernel_type=cv2.SVM_LINEAR, C=0.1)
-    model.train(samples_train, labels_train)
+    if SIMPLE:
+        model = SVM(kernel_type=cv2.SVM_LINEAR, C=0.1)
+        model.train(samples_train, labels_train)
+    else:
+        model = SVM(kernel_type=cv2.SVM_RBF, C=2.67, gamma=5.383)
+        model.train(samples_train, labels_train)
     vis = evaluate_model(model, digits_test, samples_test, labels_test)
     cv2.imshow('SVM test', vis)
 
     print 'saving SVM as "digits_svm.dat"...'
     model.save('digits_svm.dat')
 
-    cv2.waitKey(0)
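
Note (not part of the patch): hunk @@ -150,9 +162,10 @@ cuts off right at the "transform to Hellinger kernel" comment, so the HOG branch that the new SIMPLE flag toggles is easy to misread. Below is a minimal, self-contained sketch of what preprocess_hog() computes for one 20x20 digit, finishing with the Hellinger (RootSIFT-style) normalization as in the upstream OpenCV sample. The names hog_feature and bins are illustrative, not from the patch, and the snippet sticks to cv2/numpy calls that exist in both the old 2.4 API used above and current OpenCV.

    import cv2
    import numpy as np

    SZ = 20  # digit cell size, matching the patch

    def hog_feature(img, bin_n=16, eps=1e-7):
        # Gradient magnitude and orientation from Sobel derivatives.
        gx = cv2.Sobel(img, cv2.CV_32F, 1, 0)
        gy = cv2.Sobel(img, cv2.CV_32F, 0, 1)
        mag, ang = cv2.cartToPolar(gx, gy)
        # Quantize orientation into bin_n bins; one magnitude-weighted
        # histogram per 10x10 quadrant, as in the hunk above.
        bins = np.int32(bin_n*ang/(2*np.pi))
        bin_cells = bins[:10, :10], bins[10:, :10], bins[:10, 10:], bins[10:, 10:]
        mag_cells = mag[:10, :10], mag[10:, :10], mag[:10, 10:], mag[10:, 10:]
        hists = [np.bincount(b.ravel(), m.ravel(), bin_n)
                 for b, m in zip(bin_cells, mag_cells)]
        hist = np.hstack(hists)  # 4 cells x 16 bins = 64-dim descriptor
        # Hellinger/RootSIFT normalization: L1-normalize, take the square
        # root, then L2-normalize, so a dot product on the result
        # approximates the Hellinger kernel on the raw histograms.
        hist /= hist.sum() + eps
        hist = np.sqrt(hist)
        hist /= np.linalg.norm(hist) + eps
        return np.float32(hist)

    if __name__ == '__main__':
        digit = np.zeros((SZ, SZ), np.uint8)
        cv2.line(digit, (5, 15), (15, 5), 255, 2)  # synthetic stroke
        print(hog_feature(digit).shape)  # -> (64,)

With SIMPLE = True the script instead feeds the SVM the raw 400-dim pixel vectors from preprocess_simple(), paired with the linear kernel and C=0.1 (found by cross-validation, per the comment the patch removes); the HOG path keeps the upstream RBF parameters C=2.67, gamma=5.383.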