Experiment with the OpenCV digits.py example
This commit is contained in:
		
							parent
							
								
									16b142d94c
								
							
						
					
					
						commit
						47ba4b1b8a
					
				
					 4 changed files with 418 additions and 0 deletions
				
			
		
							
								
								
									
										3
									
								
								.gitignore
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
										
									
									
										vendored
									
									
								
							| 
						 | 
					@ -10,3 +10,6 @@ SVMTest.png
 | 
				
			||||||
DisplayImage
 | 
					DisplayImage
 | 
				
			||||||
 | 
					
 | 
				
			||||||
SharpenImage
 | 
					SharpenImage
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					*.pyc
 | 
				
			||||||
 | 
					digits_svm.dat
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										220
									
								
								common.py
									
										
									
									
									
										Executable file
									
								
							
							
						
						
									
										220
									
								
								common.py
									
										
									
									
									
										Executable file
									
								
							| 
						 | 
					@ -0,0 +1,220 @@
 | 
				
			||||||
 | 
					#!/usr/bin/env python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					'''
 | 
				
			||||||
 | 
					This module contains some common routines used by other samples.
 | 
				
			||||||
 | 
					'''
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import cv2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# built-in modules
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import itertools as it
 | 
				
			||||||
 | 
					from contextlib import contextmanager
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					image_extensions = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.tiff', '.pbm', '.pgm', '.ppm']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Bunch(object):
 | 
				
			||||||
 | 
					    def __init__(self, **kw):
 | 
				
			||||||
 | 
					        self.__dict__.update(kw)
 | 
				
			||||||
 | 
					    def __str__(self):
 | 
				
			||||||
 | 
					        return str(self.__dict__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def splitfn(fn):
 | 
				
			||||||
 | 
					    path, fn = os.path.split(fn)
 | 
				
			||||||
 | 
					    name, ext = os.path.splitext(fn)
 | 
				
			||||||
 | 
					    return path, name, ext
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def anorm2(a):
 | 
				
			||||||
 | 
					    return (a*a).sum(-1)
 | 
				
			||||||
 | 
					def anorm(a):
 | 
				
			||||||
 | 
					    return np.sqrt( anorm2(a) )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def homotrans(H, x, y):
 | 
				
			||||||
 | 
					    xs = H[0, 0]*x + H[0, 1]*y + H[0, 2]
 | 
				
			||||||
 | 
					    ys = H[1, 0]*x + H[1, 1]*y + H[1, 2]
 | 
				
			||||||
 | 
					    s  = H[2, 0]*x + H[2, 1]*y + H[2, 2]
 | 
				
			||||||
 | 
					    return xs/s, ys/s
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def to_rect(a):
 | 
				
			||||||
 | 
					    a = np.ravel(a)
 | 
				
			||||||
 | 
					    if len(a) == 2:
 | 
				
			||||||
 | 
					        a = (0, 0, a[0], a[1])
 | 
				
			||||||
 | 
					    return np.array(a, np.float64).reshape(2, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def rect2rect_mtx(src, dst):
 | 
				
			||||||
 | 
					    src, dst = to_rect(src), to_rect(dst)
 | 
				
			||||||
 | 
					    cx, cy = (dst[1] - dst[0]) / (src[1] - src[0])
 | 
				
			||||||
 | 
					    tx, ty = dst[0] - src[0] * (cx, cy)
 | 
				
			||||||
 | 
					    M = np.float64([[ cx,  0, tx],
 | 
				
			||||||
 | 
					                    [  0, cy, ty],
 | 
				
			||||||
 | 
					                    [  0,  0,  1]])
 | 
				
			||||||
 | 
					    return M
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def lookat(eye, target, up = (0, 0, 1)):
 | 
				
			||||||
 | 
					    fwd = np.asarray(target, np.float64) - eye
 | 
				
			||||||
 | 
					    fwd /= anorm(fwd)
 | 
				
			||||||
 | 
					    right = np.cross(fwd, up)
 | 
				
			||||||
 | 
					    right /= anorm(right)
 | 
				
			||||||
 | 
					    down = np.cross(fwd, right)
 | 
				
			||||||
 | 
					    R = np.float64([right, down, fwd])
 | 
				
			||||||
 | 
					    tvec = -np.dot(R, eye)
 | 
				
			||||||
 | 
					    return R, tvec
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def mtx2rvec(R):
 | 
				
			||||||
 | 
					    w, u, vt = cv2.SVDecomp(R - np.eye(3))
 | 
				
			||||||
 | 
					    p = vt[0] + u[:,0]*w[0]    # same as np.dot(R, vt[0])
 | 
				
			||||||
 | 
					    c = np.dot(vt[0], p)
 | 
				
			||||||
 | 
					    s = np.dot(vt[1], p)
 | 
				
			||||||
 | 
					    axis = np.cross(vt[0], vt[1])
 | 
				
			||||||
 | 
					    return axis * np.arctan2(s, c)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def draw_str(dst, (x, y), s):
 | 
				
			||||||
 | 
					    cv2.putText(dst, s, (x+1, y+1), cv2.FONT_HERSHEY_PLAIN, 1.0, (0, 0, 0), thickness = 2, lineType=cv2.LINE_AA)
 | 
				
			||||||
 | 
					    cv2.putText(dst, s, (x, y), cv2.FONT_HERSHEY_PLAIN, 1.0, (255, 255, 255), lineType=cv2.LINE_AA)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Sketcher:
 | 
				
			||||||
 | 
					    def __init__(self, windowname, dests, colors_func):
 | 
				
			||||||
 | 
					        self.prev_pt = None
 | 
				
			||||||
 | 
					        self.windowname = windowname
 | 
				
			||||||
 | 
					        self.dests = dests
 | 
				
			||||||
 | 
					        self.colors_func = colors_func
 | 
				
			||||||
 | 
					        self.dirty = False
 | 
				
			||||||
 | 
					        self.show()
 | 
				
			||||||
 | 
					        cv2.setMouseCallback(self.windowname, self.on_mouse)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def show(self):
 | 
				
			||||||
 | 
					        cv2.imshow(self.windowname, self.dests[0])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def on_mouse(self, event, x, y, flags, param):
 | 
				
			||||||
 | 
					        pt = (x, y)
 | 
				
			||||||
 | 
					        if event == cv2.EVENT_LBUTTONDOWN:
 | 
				
			||||||
 | 
					            self.prev_pt = pt
 | 
				
			||||||
 | 
					        elif event == cv2.EVENT_LBUTTONUP:
 | 
				
			||||||
 | 
					            self.prev_pt = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.prev_pt and flags & cv2.EVENT_FLAG_LBUTTON:
 | 
				
			||||||
 | 
					            for dst, color in zip(self.dests, self.colors_func()):
 | 
				
			||||||
 | 
					                cv2.line(dst, self.prev_pt, pt, color, 5)
 | 
				
			||||||
 | 
					            self.dirty = True
 | 
				
			||||||
 | 
					            self.prev_pt = pt
 | 
				
			||||||
 | 
					            self.show()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# palette data from matplotlib/_cm.py
 | 
				
			||||||
 | 
					_jet_data =   {'red':   ((0., 0, 0), (0.35, 0, 0), (0.66, 1, 1), (0.89,1, 1),
 | 
				
			||||||
 | 
					                         (1, 0.5, 0.5)),
 | 
				
			||||||
 | 
					               'green': ((0., 0, 0), (0.125,0, 0), (0.375,1, 1), (0.64,1, 1),
 | 
				
			||||||
 | 
					                         (0.91,0,0), (1, 0, 0)),
 | 
				
			||||||
 | 
					               'blue':  ((0., 0.5, 0.5), (0.11, 1, 1), (0.34, 1, 1), (0.65,0, 0),
 | 
				
			||||||
 | 
					                         (1, 0, 0))}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cmap_data = { 'jet' : _jet_data }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def make_cmap(name, n=256):
 | 
				
			||||||
 | 
					    data = cmap_data[name]
 | 
				
			||||||
 | 
					    xs = np.linspace(0.0, 1.0, n)
 | 
				
			||||||
 | 
					    channels = []
 | 
				
			||||||
 | 
					    eps = 1e-6
 | 
				
			||||||
 | 
					    for ch_name in ['blue', 'green', 'red']:
 | 
				
			||||||
 | 
					        ch_data = data[ch_name]
 | 
				
			||||||
 | 
					        xp, yp = [], []
 | 
				
			||||||
 | 
					        for x, y1, y2 in ch_data:
 | 
				
			||||||
 | 
					            xp += [x, x+eps]
 | 
				
			||||||
 | 
					            yp += [y1, y2]
 | 
				
			||||||
 | 
					        ch = np.interp(xs, xp, yp)
 | 
				
			||||||
 | 
					        channels.append(ch)
 | 
				
			||||||
 | 
					    return np.uint8(np.array(channels).T*255)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def nothing(*arg, **kw):
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def clock():
 | 
				
			||||||
 | 
					    return cv2.getTickCount() / cv2.getTickFrequency()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@contextmanager
 | 
				
			||||||
 | 
					def Timer(msg):
 | 
				
			||||||
 | 
					    print msg, '...',
 | 
				
			||||||
 | 
					    start = clock()
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        yield
 | 
				
			||||||
 | 
					    finally:
 | 
				
			||||||
 | 
					        print "%.2f ms" % ((clock()-start)*1000)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class StatValue:
 | 
				
			||||||
 | 
					    def __init__(self, smooth_coef = 0.5):
 | 
				
			||||||
 | 
					        self.value = None
 | 
				
			||||||
 | 
					        self.smooth_coef = smooth_coef
 | 
				
			||||||
 | 
					    def update(self, v):
 | 
				
			||||||
 | 
					        if self.value is None:
 | 
				
			||||||
 | 
					            self.value = v
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            c = self.smooth_coef
 | 
				
			||||||
 | 
					            self.value = c * self.value + (1.0-c) * v
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class RectSelector:
 | 
				
			||||||
 | 
					    def __init__(self, win, callback):
 | 
				
			||||||
 | 
					        self.win = win
 | 
				
			||||||
 | 
					        self.callback = callback
 | 
				
			||||||
 | 
					        cv2.setMouseCallback(win, self.onmouse)
 | 
				
			||||||
 | 
					        self.drag_start = None
 | 
				
			||||||
 | 
					        self.drag_rect = None
 | 
				
			||||||
 | 
					    def onmouse(self, event, x, y, flags, param):
 | 
				
			||||||
 | 
					        x, y = np.int16([x, y]) # BUG
 | 
				
			||||||
 | 
					        if event == cv2.EVENT_LBUTTONDOWN:
 | 
				
			||||||
 | 
					            self.drag_start = (x, y)
 | 
				
			||||||
 | 
					        if self.drag_start:
 | 
				
			||||||
 | 
					            if flags & cv2.EVENT_FLAG_LBUTTON:
 | 
				
			||||||
 | 
					                xo, yo = self.drag_start
 | 
				
			||||||
 | 
					                x0, y0 = np.minimum([xo, yo], [x, y])
 | 
				
			||||||
 | 
					                x1, y1 = np.maximum([xo, yo], [x, y])
 | 
				
			||||||
 | 
					                self.drag_rect = None
 | 
				
			||||||
 | 
					                if x1-x0 > 0 and y1-y0 > 0:
 | 
				
			||||||
 | 
					                    self.drag_rect = (x0, y0, x1, y1)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                rect = self.drag_rect
 | 
				
			||||||
 | 
					                self.drag_start = None
 | 
				
			||||||
 | 
					                self.drag_rect = None
 | 
				
			||||||
 | 
					                if rect:
 | 
				
			||||||
 | 
					                    self.callback(rect)
 | 
				
			||||||
 | 
					    def draw(self, vis):
 | 
				
			||||||
 | 
					        if not self.drag_rect:
 | 
				
			||||||
 | 
					            return False
 | 
				
			||||||
 | 
					        x0, y0, x1, y1 = self.drag_rect
 | 
				
			||||||
 | 
					        cv2.rectangle(vis, (x0, y0), (x1, y1), (0, 255, 0), 2)
 | 
				
			||||||
 | 
					        return True
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def dragging(self):
 | 
				
			||||||
 | 
					        return self.drag_rect is not None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def grouper(n, iterable, fillvalue=None):
 | 
				
			||||||
 | 
					    '''grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx'''
 | 
				
			||||||
 | 
					    args = [iter(iterable)] * n
 | 
				
			||||||
 | 
					    return it.izip_longest(fillvalue=fillvalue, *args)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def mosaic(w, imgs):
 | 
				
			||||||
 | 
					    '''Make a grid from images.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    w    -- number of grid columns
 | 
				
			||||||
 | 
					    imgs -- images (must have same size and format)
 | 
				
			||||||
 | 
					    '''
 | 
				
			||||||
 | 
					    imgs = iter(imgs)
 | 
				
			||||||
 | 
					    img0 = imgs.next()
 | 
				
			||||||
 | 
					    pad = np.zeros_like(img0)
 | 
				
			||||||
 | 
					    imgs = it.chain([img0], imgs)
 | 
				
			||||||
 | 
					    rows = grouper(w, imgs, pad)
 | 
				
			||||||
 | 
					    return np.vstack(map(np.hstack, rows))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def getsize(img):
 | 
				
			||||||
 | 
					    h, w = img.shape[:2]
 | 
				
			||||||
 | 
					    return w, h
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def mdot(*args):
 | 
				
			||||||
 | 
					    return reduce(np.dot, args)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def draw_keypoints(vis, keypoints, color = (0, 255, 255)):
 | 
				
			||||||
 | 
					    for kp in keypoints:
 | 
				
			||||||
 | 
					            x, y = kp.pt
 | 
				
			||||||
 | 
					            cv2.circle(vis, (int(x), int(y)), 2, color)
 | 
				
			||||||
							
								
								
									
										
											BIN
										
									
								
								digits.png
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								digits.png
									
										
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 704 KiB  | 
							
								
								
									
										195
									
								
								digits.py
									
										
									
									
									
										Executable file
									
								
							
							
						
						
									
										195
									
								
								digits.py
									
										
									
									
									
										Executable file
									
								
							| 
						 | 
					@ -0,0 +1,195 @@
 | 
				
			||||||
 | 
					#!/usr/bin/env python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					'''
 | 
				
			||||||
 | 
					SVM and KNearest digit recognition.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Sample loads a dataset of handwritten digits from '../data/digits.png'.
 | 
				
			||||||
 | 
					Then it trains a SVM and KNearest classifiers on it and evaluates
 | 
				
			||||||
 | 
					their accuracy.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Following preprocessing is applied to the dataset:
 | 
				
			||||||
 | 
					 - Moment-based image deskew (see deskew())
 | 
				
			||||||
 | 
					 - Digit images are split into 4 10x10 cells and 16-bin
 | 
				
			||||||
 | 
					   histogram of oriented gradients is computed for each
 | 
				
			||||||
 | 
					   cell
 | 
				
			||||||
 | 
					 - Transform histograms to space with Hellinger metric (see [1] (RootSIFT))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[1] R. Arandjelovic, A. Zisserman
 | 
				
			||||||
 | 
					    "Three things everyone should know to improve object retrieval"
 | 
				
			||||||
 | 
					    http://www.robots.ox.ac.uk/~vgg/publications/2012/Arandjelovic12/arandjelovic12.pdf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Usage:
 | 
				
			||||||
 | 
					   digits.py
 | 
				
			||||||
 | 
					'''
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# built-in modules
 | 
				
			||||||
 | 
					from multiprocessing.pool import ThreadPool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import cv2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					from numpy.linalg import norm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# local modules
 | 
				
			||||||
 | 
					from common import clock, mosaic
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					SZ = 20 # size of each digit is SZ x SZ
 | 
				
			||||||
 | 
					CLASS_N = 10
 | 
				
			||||||
 | 
					DIGITS_FN = 'digits.png'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def split2d(img, cell_size, flatten=True):
 | 
				
			||||||
 | 
					    h, w = img.shape[:2]
 | 
				
			||||||
 | 
					    sx, sy = cell_size
 | 
				
			||||||
 | 
					    cells = [np.hsplit(row, w//sx) for row in np.vsplit(img, h//sy)]
 | 
				
			||||||
 | 
					    cells = np.array(cells)
 | 
				
			||||||
 | 
					    if flatten:
 | 
				
			||||||
 | 
					        cells = cells.reshape(-1, sy, sx)
 | 
				
			||||||
 | 
					    return cells
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def load_digits(fn):
 | 
				
			||||||
 | 
					    print 'loading "%s" ...' % fn
 | 
				
			||||||
 | 
					    digits_img = cv2.imread(fn, 0)
 | 
				
			||||||
 | 
					    digits = split2d(digits_img, (SZ, SZ))
 | 
				
			||||||
 | 
					    labels = np.repeat(np.arange(CLASS_N), len(digits)/CLASS_N)
 | 
				
			||||||
 | 
					    return digits, labels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def deskew(img):
 | 
				
			||||||
 | 
					    m = cv2.moments(img)
 | 
				
			||||||
 | 
					    if abs(m['mu02']) < 1e-2:
 | 
				
			||||||
 | 
					        return img.copy()
 | 
				
			||||||
 | 
					    skew = m['mu11']/m['mu02']
 | 
				
			||||||
 | 
					    M = np.float32([[1, skew, -0.5*SZ*skew], [0, 1, 0]])
 | 
				
			||||||
 | 
					    img = cv2.warpAffine(img, M, (SZ, SZ), flags=cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR)
 | 
				
			||||||
 | 
					    return img
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class StatModel(object):
 | 
				
			||||||
 | 
					    def load(self, fn):
 | 
				
			||||||
 | 
					        self.model.load(fn)
 | 
				
			||||||
 | 
					    def save(self, fn):
 | 
				
			||||||
 | 
					        self.model.save(fn)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class KNearest(StatModel):
 | 
				
			||||||
 | 
					    def __init__(self, k = 3):
 | 
				
			||||||
 | 
					        self.k = k
 | 
				
			||||||
 | 
					        self.model = cv2.KNearest()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def train(self, samples, responses):
 | 
				
			||||||
 | 
					        self.model = cv2.KNearest()
 | 
				
			||||||
 | 
					        self.model.train(samples, responses)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def predict(self, samples):
 | 
				
			||||||
 | 
					        retval, results, neigh_resp, dists = self.model.find_nearest(samples, self.k)
 | 
				
			||||||
 | 
					        return results.ravel()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class SVM(StatModel):
 | 
				
			||||||
 | 
					    def __init__(self, kernel_type=cv2.SVM_RBF, C=1, gamma=0.5):
 | 
				
			||||||
 | 
					        self.params = dict( kernel_type = kernel_type,
 | 
				
			||||||
 | 
					                            svm_type = cv2.SVM_C_SVC,
 | 
				
			||||||
 | 
					                            C = C,
 | 
				
			||||||
 | 
					                            gamma = gamma )
 | 
				
			||||||
 | 
					        self.model = cv2.SVM()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def train(self, samples, responses):
 | 
				
			||||||
 | 
					        self.model = cv2.SVM()
 | 
				
			||||||
 | 
					        self.model.train(samples, responses, params = self.params)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def train_auto(self, samples, responses):
 | 
				
			||||||
 | 
					        self.model = cv2.SVM()
 | 
				
			||||||
 | 
					        self.model.train_auto(samples, responses, None, None, params = self.params)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def predict(self, samples):
 | 
				
			||||||
 | 
					        return self.model.predict_all(samples).ravel()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def evaluate_model(model, digits, samples, labels):
 | 
				
			||||||
 | 
					    resp = model.predict(samples)
 | 
				
			||||||
 | 
					    err = (labels != resp).mean()
 | 
				
			||||||
 | 
					    print 'error: %.2f %%' % (err*100)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    confusion = np.zeros((10, 10), np.int32)
 | 
				
			||||||
 | 
					    for i, j in zip(labels, resp):
 | 
				
			||||||
 | 
					        confusion[i, j] += 1
 | 
				
			||||||
 | 
					    print 'confusion matrix:'
 | 
				
			||||||
 | 
					    print confusion
 | 
				
			||||||
 | 
					    print
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    vis = []
 | 
				
			||||||
 | 
					    for img, flag in zip(digits, resp == labels):
 | 
				
			||||||
 | 
					        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
 | 
				
			||||||
 | 
					        if not flag:
 | 
				
			||||||
 | 
					            img[...,:2] = 0
 | 
				
			||||||
 | 
					        vis.append(img)
 | 
				
			||||||
 | 
					    return mosaic(25, vis)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def preprocess_simple(digits):
 | 
				
			||||||
 | 
					    return np.float32(digits).reshape(-1, SZ*SZ) / 255.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def preprocess_hog(digits):
 | 
				
			||||||
 | 
					    samples = []
 | 
				
			||||||
 | 
					    for img in digits:
 | 
				
			||||||
 | 
					        gx = cv2.Sobel(img, cv2.CV_32F, 1, 0)
 | 
				
			||||||
 | 
					        gy = cv2.Sobel(img, cv2.CV_32F, 0, 1)
 | 
				
			||||||
 | 
					        mag, ang = cv2.cartToPolar(gx, gy)
 | 
				
			||||||
 | 
					        bin_n = 16
 | 
				
			||||||
 | 
					        bin = np.int32(bin_n*ang/(2*np.pi))
 | 
				
			||||||
 | 
					        bin_cells = bin[:10,:10], bin[10:,:10], bin[:10,10:], bin[10:,10:]
 | 
				
			||||||
 | 
					        mag_cells = mag[:10,:10], mag[10:,:10], mag[:10,10:], mag[10:,10:]
 | 
				
			||||||
 | 
					        hists = [np.bincount(b.ravel(), m.ravel(), bin_n) for b, m in zip(bin_cells, mag_cells)]
 | 
				
			||||||
 | 
					        hist = np.hstack(hists)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # transform to Hellinger kernel
 | 
				
			||||||
 | 
					        eps = 1e-7
 | 
				
			||||||
 | 
					        hist /= hist.sum() + eps
 | 
				
			||||||
 | 
					        hist = np.sqrt(hist)
 | 
				
			||||||
 | 
					        hist /= norm(hist) + eps
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        samples.append(hist)
 | 
				
			||||||
 | 
					    return np.float32(samples)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == '__main__':
 | 
				
			||||||
 | 
					    print __doc__
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    digits, labels = load_digits(DIGITS_FN)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    print 'preprocessing...'
 | 
				
			||||||
 | 
					    # shuffle digits
 | 
				
			||||||
 | 
					    rand = np.random.RandomState(321)
 | 
				
			||||||
 | 
					    shuffle = rand.permutation(len(digits))
 | 
				
			||||||
 | 
					    digits, labels = digits[shuffle], labels[shuffle]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    digits2 = map(deskew, digits)
 | 
				
			||||||
 | 
					    samples = preprocess_simple(digits2)
 | 
				
			||||||
 | 
					    #samples = preprocess_hog(digits2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    train_n = int(0.9*len(samples))
 | 
				
			||||||
 | 
					    cv2.imshow('test set', mosaic(25, digits[train_n:]))
 | 
				
			||||||
 | 
					    digits_train, digits_test = np.split(digits2, [train_n])
 | 
				
			||||||
 | 
					    samples_train, samples_test = np.split(samples, [train_n])
 | 
				
			||||||
 | 
					    labels_train, labels_test = np.split(labels, [train_n])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    print 'training KNearest...'
 | 
				
			||||||
 | 
					    model = KNearest(k=4)
 | 
				
			||||||
 | 
					    model.train(samples_train, labels_train)
 | 
				
			||||||
 | 
					    vis = evaluate_model(model, digits_test, samples_test, labels_test)
 | 
				
			||||||
 | 
					    cv2.imshow('KNearest test', vis)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    print 'training SVM...'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # HOG (original digits.py)
 | 
				
			||||||
 | 
					    #model = SVM(kernel_type=cv2.SVM_RBF, C=2.67, gamma=5.383)
 | 
				
			||||||
 | 
					    #model.train(samples_train, labels_train)
 | 
				
			||||||
 | 
					    # Simple (cross-validation)
 | 
				
			||||||
 | 
					    model = SVM(kernel_type=cv2.SVM_LINEAR, C=0.1)
 | 
				
			||||||
 | 
					    model.train(samples_train, labels_train)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    vis = evaluate_model(model, digits_test, samples_test, labels_test)
 | 
				
			||||||
 | 
					    cv2.imshow('SVM test', vis)
 | 
				
			||||||
 | 
					    print 'saving SVM as "digits_svm.dat"...'
 | 
				
			||||||
 | 
					    model.save('digits_svm.dat')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    cv2.waitKey(0)
 | 
				
			||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue