Experiment with the OpenCV digits.py example
This commit is contained in:
		
							parent
							
								
									16b142d94c
								
							
						
					
					
						commit
						47ba4b1b8a
					
				
					 4 changed files with 418 additions and 0 deletions
				
			
		
							
								
								
									
										3
									
								
								.gitignore
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
										
									
									
										vendored
									
									
								
							| 
						 | 
				
			
			@ -10,3 +10,6 @@ SVMTest.png
 | 
			
		|||
DisplayImage
 | 
			
		||||
 | 
			
		||||
SharpenImage
 | 
			
		||||
 | 
			
		||||
*.pyc
 | 
			
		||||
digits_svm.dat
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										220
									
								
								common.py
									
										
									
									
									
										Executable file
									
								
							
							
						
						
									
										220
									
								
								common.py
									
										
									
									
									
										Executable file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,220 @@
 | 
			
		|||
#!/usr/bin/env python
 | 
			
		||||
 | 
			
		||||
'''
 | 
			
		||||
This module contains some common routines used by other samples.
 | 
			
		||||
'''
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import cv2
 | 
			
		||||
 | 
			
		||||
# built-in modules
 | 
			
		||||
import os
 | 
			
		||||
import itertools as it
 | 
			
		||||
from contextlib import contextmanager
 | 
			
		||||
 | 
			
		||||
image_extensions = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.tiff', '.pbm', '.pgm', '.ppm']
 | 
			
		||||
 | 
			
		||||
class Bunch(object):
 | 
			
		||||
    def __init__(self, **kw):
 | 
			
		||||
        self.__dict__.update(kw)
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return str(self.__dict__)
 | 
			
		||||
 | 
			
		||||
def splitfn(fn):
 | 
			
		||||
    path, fn = os.path.split(fn)
 | 
			
		||||
    name, ext = os.path.splitext(fn)
 | 
			
		||||
    return path, name, ext
 | 
			
		||||
 | 
			
		||||
def anorm2(a):
 | 
			
		||||
    return (a*a).sum(-1)
 | 
			
		||||
def anorm(a):
 | 
			
		||||
    return np.sqrt( anorm2(a) )
 | 
			
		||||
 | 
			
		||||
def homotrans(H, x, y):
 | 
			
		||||
    xs = H[0, 0]*x + H[0, 1]*y + H[0, 2]
 | 
			
		||||
    ys = H[1, 0]*x + H[1, 1]*y + H[1, 2]
 | 
			
		||||
    s  = H[2, 0]*x + H[2, 1]*y + H[2, 2]
 | 
			
		||||
    return xs/s, ys/s
 | 
			
		||||
 | 
			
		||||
def to_rect(a):
 | 
			
		||||
    a = np.ravel(a)
 | 
			
		||||
    if len(a) == 2:
 | 
			
		||||
        a = (0, 0, a[0], a[1])
 | 
			
		||||
    return np.array(a, np.float64).reshape(2, 2)
 | 
			
		||||
 | 
			
		||||
def rect2rect_mtx(src, dst):
 | 
			
		||||
    src, dst = to_rect(src), to_rect(dst)
 | 
			
		||||
    cx, cy = (dst[1] - dst[0]) / (src[1] - src[0])
 | 
			
		||||
    tx, ty = dst[0] - src[0] * (cx, cy)
 | 
			
		||||
    M = np.float64([[ cx,  0, tx],
 | 
			
		||||
                    [  0, cy, ty],
 | 
			
		||||
                    [  0,  0,  1]])
 | 
			
		||||
    return M
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def lookat(eye, target, up = (0, 0, 1)):
 | 
			
		||||
    fwd = np.asarray(target, np.float64) - eye
 | 
			
		||||
    fwd /= anorm(fwd)
 | 
			
		||||
    right = np.cross(fwd, up)
 | 
			
		||||
    right /= anorm(right)
 | 
			
		||||
    down = np.cross(fwd, right)
 | 
			
		||||
    R = np.float64([right, down, fwd])
 | 
			
		||||
    tvec = -np.dot(R, eye)
 | 
			
		||||
    return R, tvec
 | 
			
		||||
 | 
			
		||||
def mtx2rvec(R):
 | 
			
		||||
    w, u, vt = cv2.SVDecomp(R - np.eye(3))
 | 
			
		||||
    p = vt[0] + u[:,0]*w[0]    # same as np.dot(R, vt[0])
 | 
			
		||||
    c = np.dot(vt[0], p)
 | 
			
		||||
    s = np.dot(vt[1], p)
 | 
			
		||||
    axis = np.cross(vt[0], vt[1])
 | 
			
		||||
    return axis * np.arctan2(s, c)
 | 
			
		||||
 | 
			
		||||
def draw_str(dst, (x, y), s):
 | 
			
		||||
    cv2.putText(dst, s, (x+1, y+1), cv2.FONT_HERSHEY_PLAIN, 1.0, (0, 0, 0), thickness = 2, lineType=cv2.LINE_AA)
 | 
			
		||||
    cv2.putText(dst, s, (x, y), cv2.FONT_HERSHEY_PLAIN, 1.0, (255, 255, 255), lineType=cv2.LINE_AA)
 | 
			
		||||
 | 
			
		||||
class Sketcher:
 | 
			
		||||
    def __init__(self, windowname, dests, colors_func):
 | 
			
		||||
        self.prev_pt = None
 | 
			
		||||
        self.windowname = windowname
 | 
			
		||||
        self.dests = dests
 | 
			
		||||
        self.colors_func = colors_func
 | 
			
		||||
        self.dirty = False
 | 
			
		||||
        self.show()
 | 
			
		||||
        cv2.setMouseCallback(self.windowname, self.on_mouse)
 | 
			
		||||
 | 
			
		||||
    def show(self):
 | 
			
		||||
        cv2.imshow(self.windowname, self.dests[0])
 | 
			
		||||
 | 
			
		||||
    def on_mouse(self, event, x, y, flags, param):
 | 
			
		||||
        pt = (x, y)
 | 
			
		||||
        if event == cv2.EVENT_LBUTTONDOWN:
 | 
			
		||||
            self.prev_pt = pt
 | 
			
		||||
        elif event == cv2.EVENT_LBUTTONUP:
 | 
			
		||||
            self.prev_pt = None
 | 
			
		||||
 | 
			
		||||
        if self.prev_pt and flags & cv2.EVENT_FLAG_LBUTTON:
 | 
			
		||||
            for dst, color in zip(self.dests, self.colors_func()):
 | 
			
		||||
                cv2.line(dst, self.prev_pt, pt, color, 5)
 | 
			
		||||
            self.dirty = True
 | 
			
		||||
            self.prev_pt = pt
 | 
			
		||||
            self.show()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# palette data from matplotlib/_cm.py
 | 
			
		||||
_jet_data =   {'red':   ((0., 0, 0), (0.35, 0, 0), (0.66, 1, 1), (0.89,1, 1),
 | 
			
		||||
                         (1, 0.5, 0.5)),
 | 
			
		||||
               'green': ((0., 0, 0), (0.125,0, 0), (0.375,1, 1), (0.64,1, 1),
 | 
			
		||||
                         (0.91,0,0), (1, 0, 0)),
 | 
			
		||||
               'blue':  ((0., 0.5, 0.5), (0.11, 1, 1), (0.34, 1, 1), (0.65,0, 0),
 | 
			
		||||
                         (1, 0, 0))}
 | 
			
		||||
 | 
			
		||||
cmap_data = { 'jet' : _jet_data }
 | 
			
		||||
 | 
			
		||||
def make_cmap(name, n=256):
 | 
			
		||||
    data = cmap_data[name]
 | 
			
		||||
    xs = np.linspace(0.0, 1.0, n)
 | 
			
		||||
    channels = []
 | 
			
		||||
    eps = 1e-6
 | 
			
		||||
    for ch_name in ['blue', 'green', 'red']:
 | 
			
		||||
        ch_data = data[ch_name]
 | 
			
		||||
        xp, yp = [], []
 | 
			
		||||
        for x, y1, y2 in ch_data:
 | 
			
		||||
            xp += [x, x+eps]
 | 
			
		||||
            yp += [y1, y2]
 | 
			
		||||
        ch = np.interp(xs, xp, yp)
 | 
			
		||||
        channels.append(ch)
 | 
			
		||||
    return np.uint8(np.array(channels).T*255)
 | 
			
		||||
 | 
			
		||||
def nothing(*arg, **kw):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
def clock():
 | 
			
		||||
    return cv2.getTickCount() / cv2.getTickFrequency()
 | 
			
		||||
 | 
			
		||||
@contextmanager
 | 
			
		||||
def Timer(msg):
 | 
			
		||||
    print msg, '...',
 | 
			
		||||
    start = clock()
 | 
			
		||||
    try:
 | 
			
		||||
        yield
 | 
			
		||||
    finally:
 | 
			
		||||
        print "%.2f ms" % ((clock()-start)*1000)
 | 
			
		||||
 | 
			
		||||
class StatValue:
 | 
			
		||||
    def __init__(self, smooth_coef = 0.5):
 | 
			
		||||
        self.value = None
 | 
			
		||||
        self.smooth_coef = smooth_coef
 | 
			
		||||
    def update(self, v):
 | 
			
		||||
        if self.value is None:
 | 
			
		||||
            self.value = v
 | 
			
		||||
        else:
 | 
			
		||||
            c = self.smooth_coef
 | 
			
		||||
            self.value = c * self.value + (1.0-c) * v
 | 
			
		||||
 | 
			
		||||
class RectSelector:
 | 
			
		||||
    def __init__(self, win, callback):
 | 
			
		||||
        self.win = win
 | 
			
		||||
        self.callback = callback
 | 
			
		||||
        cv2.setMouseCallback(win, self.onmouse)
 | 
			
		||||
        self.drag_start = None
 | 
			
		||||
        self.drag_rect = None
 | 
			
		||||
    def onmouse(self, event, x, y, flags, param):
 | 
			
		||||
        x, y = np.int16([x, y]) # BUG
 | 
			
		||||
        if event == cv2.EVENT_LBUTTONDOWN:
 | 
			
		||||
            self.drag_start = (x, y)
 | 
			
		||||
        if self.drag_start:
 | 
			
		||||
            if flags & cv2.EVENT_FLAG_LBUTTON:
 | 
			
		||||
                xo, yo = self.drag_start
 | 
			
		||||
                x0, y0 = np.minimum([xo, yo], [x, y])
 | 
			
		||||
                x1, y1 = np.maximum([xo, yo], [x, y])
 | 
			
		||||
                self.drag_rect = None
 | 
			
		||||
                if x1-x0 > 0 and y1-y0 > 0:
 | 
			
		||||
                    self.drag_rect = (x0, y0, x1, y1)
 | 
			
		||||
            else:
 | 
			
		||||
                rect = self.drag_rect
 | 
			
		||||
                self.drag_start = None
 | 
			
		||||
                self.drag_rect = None
 | 
			
		||||
                if rect:
 | 
			
		||||
                    self.callback(rect)
 | 
			
		||||
    def draw(self, vis):
 | 
			
		||||
        if not self.drag_rect:
 | 
			
		||||
            return False
 | 
			
		||||
        x0, y0, x1, y1 = self.drag_rect
 | 
			
		||||
        cv2.rectangle(vis, (x0, y0), (x1, y1), (0, 255, 0), 2)
 | 
			
		||||
        return True
 | 
			
		||||
    @property
 | 
			
		||||
    def dragging(self):
 | 
			
		||||
        return self.drag_rect is not None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def grouper(n, iterable, fillvalue=None):
 | 
			
		||||
    '''grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx'''
 | 
			
		||||
    args = [iter(iterable)] * n
 | 
			
		||||
    return it.izip_longest(fillvalue=fillvalue, *args)
 | 
			
		||||
 | 
			
		||||
def mosaic(w, imgs):
 | 
			
		||||
    '''Make a grid from images.
 | 
			
		||||
 | 
			
		||||
    w    -- number of grid columns
 | 
			
		||||
    imgs -- images (must have same size and format)
 | 
			
		||||
    '''
 | 
			
		||||
    imgs = iter(imgs)
 | 
			
		||||
    img0 = imgs.next()
 | 
			
		||||
    pad = np.zeros_like(img0)
 | 
			
		||||
    imgs = it.chain([img0], imgs)
 | 
			
		||||
    rows = grouper(w, imgs, pad)
 | 
			
		||||
    return np.vstack(map(np.hstack, rows))
 | 
			
		||||
 | 
			
		||||
def getsize(img):
 | 
			
		||||
    h, w = img.shape[:2]
 | 
			
		||||
    return w, h
 | 
			
		||||
 | 
			
		||||
def mdot(*args):
 | 
			
		||||
    return reduce(np.dot, args)
 | 
			
		||||
 | 
			
		||||
def draw_keypoints(vis, keypoints, color = (0, 255, 255)):
 | 
			
		||||
    for kp in keypoints:
 | 
			
		||||
            x, y = kp.pt
 | 
			
		||||
            cv2.circle(vis, (int(x), int(y)), 2, color)
 | 
			
		||||
							
								
								
									
										
											BIN
										
									
								
								digits.png
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								digits.png
									
										
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 704 KiB  | 
							
								
								
									
										195
									
								
								digits.py
									
										
									
									
									
										Executable file
									
								
							
							
						
						
									
										195
									
								
								digits.py
									
										
									
									
									
										Executable file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,195 @@
 | 
			
		|||
#!/usr/bin/env python
 | 
			
		||||
 | 
			
		||||
'''
 | 
			
		||||
SVM and KNearest digit recognition.
 | 
			
		||||
 | 
			
		||||
Sample loads a dataset of handwritten digits from '../data/digits.png'.
 | 
			
		||||
Then it trains a SVM and KNearest classifiers on it and evaluates
 | 
			
		||||
their accuracy.
 | 
			
		||||
 | 
			
		||||
Following preprocessing is applied to the dataset:
 | 
			
		||||
 - Moment-based image deskew (see deskew())
 | 
			
		||||
 - Digit images are split into 4 10x10 cells and 16-bin
 | 
			
		||||
   histogram of oriented gradients is computed for each
 | 
			
		||||
   cell
 | 
			
		||||
 - Transform histograms to space with Hellinger metric (see [1] (RootSIFT))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
[1] R. Arandjelovic, A. Zisserman
 | 
			
		||||
    "Three things everyone should know to improve object retrieval"
 | 
			
		||||
    http://www.robots.ox.ac.uk/~vgg/publications/2012/Arandjelovic12/arandjelovic12.pdf
 | 
			
		||||
 | 
			
		||||
Usage:
 | 
			
		||||
   digits.py
 | 
			
		||||
'''
 | 
			
		||||
 | 
			
		||||
# built-in modules
 | 
			
		||||
from multiprocessing.pool import ThreadPool
 | 
			
		||||
 | 
			
		||||
import cv2
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
from numpy.linalg import norm
 | 
			
		||||
 | 
			
		||||
# local modules
 | 
			
		||||
from common import clock, mosaic
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
SZ = 20 # size of each digit is SZ x SZ
 | 
			
		||||
CLASS_N = 10
 | 
			
		||||
DIGITS_FN = 'digits.png'
 | 
			
		||||
 | 
			
		||||
def split2d(img, cell_size, flatten=True):
 | 
			
		||||
    h, w = img.shape[:2]
 | 
			
		||||
    sx, sy = cell_size
 | 
			
		||||
    cells = [np.hsplit(row, w//sx) for row in np.vsplit(img, h//sy)]
 | 
			
		||||
    cells = np.array(cells)
 | 
			
		||||
    if flatten:
 | 
			
		||||
        cells = cells.reshape(-1, sy, sx)
 | 
			
		||||
    return cells
 | 
			
		||||
 | 
			
		||||
def load_digits(fn):
 | 
			
		||||
    print 'loading "%s" ...' % fn
 | 
			
		||||
    digits_img = cv2.imread(fn, 0)
 | 
			
		||||
    digits = split2d(digits_img, (SZ, SZ))
 | 
			
		||||
    labels = np.repeat(np.arange(CLASS_N), len(digits)/CLASS_N)
 | 
			
		||||
    return digits, labels
 | 
			
		||||
 | 
			
		||||
def deskew(img):
 | 
			
		||||
    m = cv2.moments(img)
 | 
			
		||||
    if abs(m['mu02']) < 1e-2:
 | 
			
		||||
        return img.copy()
 | 
			
		||||
    skew = m['mu11']/m['mu02']
 | 
			
		||||
    M = np.float32([[1, skew, -0.5*SZ*skew], [0, 1, 0]])
 | 
			
		||||
    img = cv2.warpAffine(img, M, (SZ, SZ), flags=cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR)
 | 
			
		||||
    return img
 | 
			
		||||
 | 
			
		||||
class StatModel(object):
 | 
			
		||||
    def load(self, fn):
 | 
			
		||||
        self.model.load(fn)
 | 
			
		||||
    def save(self, fn):
 | 
			
		||||
        self.model.save(fn)
 | 
			
		||||
 | 
			
		||||
class KNearest(StatModel):
 | 
			
		||||
    def __init__(self, k = 3):
 | 
			
		||||
        self.k = k
 | 
			
		||||
        self.model = cv2.KNearest()
 | 
			
		||||
 | 
			
		||||
    def train(self, samples, responses):
 | 
			
		||||
        self.model = cv2.KNearest()
 | 
			
		||||
        self.model.train(samples, responses)
 | 
			
		||||
 | 
			
		||||
    def predict(self, samples):
 | 
			
		||||
        retval, results, neigh_resp, dists = self.model.find_nearest(samples, self.k)
 | 
			
		||||
        return results.ravel()
 | 
			
		||||
 | 
			
		||||
class SVM(StatModel):
 | 
			
		||||
    def __init__(self, kernel_type=cv2.SVM_RBF, C=1, gamma=0.5):
 | 
			
		||||
        self.params = dict( kernel_type = kernel_type,
 | 
			
		||||
                            svm_type = cv2.SVM_C_SVC,
 | 
			
		||||
                            C = C,
 | 
			
		||||
                            gamma = gamma )
 | 
			
		||||
        self.model = cv2.SVM()
 | 
			
		||||
 | 
			
		||||
    def train(self, samples, responses):
 | 
			
		||||
        self.model = cv2.SVM()
 | 
			
		||||
        self.model.train(samples, responses, params = self.params)
 | 
			
		||||
 | 
			
		||||
    def train_auto(self, samples, responses):
 | 
			
		||||
        self.model = cv2.SVM()
 | 
			
		||||
        self.model.train_auto(samples, responses, None, None, params = self.params)
 | 
			
		||||
 | 
			
		||||
    def predict(self, samples):
 | 
			
		||||
        return self.model.predict_all(samples).ravel()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def evaluate_model(model, digits, samples, labels):
 | 
			
		||||
    resp = model.predict(samples)
 | 
			
		||||
    err = (labels != resp).mean()
 | 
			
		||||
    print 'error: %.2f %%' % (err*100)
 | 
			
		||||
 | 
			
		||||
    confusion = np.zeros((10, 10), np.int32)
 | 
			
		||||
    for i, j in zip(labels, resp):
 | 
			
		||||
        confusion[i, j] += 1
 | 
			
		||||
    print 'confusion matrix:'
 | 
			
		||||
    print confusion
 | 
			
		||||
    print
 | 
			
		||||
 | 
			
		||||
    vis = []
 | 
			
		||||
    for img, flag in zip(digits, resp == labels):
 | 
			
		||||
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
 | 
			
		||||
        if not flag:
 | 
			
		||||
            img[...,:2] = 0
 | 
			
		||||
        vis.append(img)
 | 
			
		||||
    return mosaic(25, vis)
 | 
			
		||||
 | 
			
		||||
def preprocess_simple(digits):
 | 
			
		||||
    return np.float32(digits).reshape(-1, SZ*SZ) / 255.0
 | 
			
		||||
 | 
			
		||||
def preprocess_hog(digits):
 | 
			
		||||
    samples = []
 | 
			
		||||
    for img in digits:
 | 
			
		||||
        gx = cv2.Sobel(img, cv2.CV_32F, 1, 0)
 | 
			
		||||
        gy = cv2.Sobel(img, cv2.CV_32F, 0, 1)
 | 
			
		||||
        mag, ang = cv2.cartToPolar(gx, gy)
 | 
			
		||||
        bin_n = 16
 | 
			
		||||
        bin = np.int32(bin_n*ang/(2*np.pi))
 | 
			
		||||
        bin_cells = bin[:10,:10], bin[10:,:10], bin[:10,10:], bin[10:,10:]
 | 
			
		||||
        mag_cells = mag[:10,:10], mag[10:,:10], mag[:10,10:], mag[10:,10:]
 | 
			
		||||
        hists = [np.bincount(b.ravel(), m.ravel(), bin_n) for b, m in zip(bin_cells, mag_cells)]
 | 
			
		||||
        hist = np.hstack(hists)
 | 
			
		||||
 | 
			
		||||
        # transform to Hellinger kernel
 | 
			
		||||
        eps = 1e-7
 | 
			
		||||
        hist /= hist.sum() + eps
 | 
			
		||||
        hist = np.sqrt(hist)
 | 
			
		||||
        hist /= norm(hist) + eps
 | 
			
		||||
 | 
			
		||||
        samples.append(hist)
 | 
			
		||||
    return np.float32(samples)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    print __doc__
 | 
			
		||||
 | 
			
		||||
    digits, labels = load_digits(DIGITS_FN)
 | 
			
		||||
 | 
			
		||||
    print 'preprocessing...'
 | 
			
		||||
    # shuffle digits
 | 
			
		||||
    rand = np.random.RandomState(321)
 | 
			
		||||
    shuffle = rand.permutation(len(digits))
 | 
			
		||||
    digits, labels = digits[shuffle], labels[shuffle]
 | 
			
		||||
 | 
			
		||||
    digits2 = map(deskew, digits)
 | 
			
		||||
    samples = preprocess_simple(digits2)
 | 
			
		||||
    #samples = preprocess_hog(digits2)
 | 
			
		||||
 | 
			
		||||
    train_n = int(0.9*len(samples))
 | 
			
		||||
    cv2.imshow('test set', mosaic(25, digits[train_n:]))
 | 
			
		||||
    digits_train, digits_test = np.split(digits2, [train_n])
 | 
			
		||||
    samples_train, samples_test = np.split(samples, [train_n])
 | 
			
		||||
    labels_train, labels_test = np.split(labels, [train_n])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    print 'training KNearest...'
 | 
			
		||||
    model = KNearest(k=4)
 | 
			
		||||
    model.train(samples_train, labels_train)
 | 
			
		||||
    vis = evaluate_model(model, digits_test, samples_test, labels_test)
 | 
			
		||||
    cv2.imshow('KNearest test', vis)
 | 
			
		||||
 | 
			
		||||
    print 'training SVM...'
 | 
			
		||||
 | 
			
		||||
    # HOG (original digits.py)
 | 
			
		||||
    #model = SVM(kernel_type=cv2.SVM_RBF, C=2.67, gamma=5.383)
 | 
			
		||||
    #model.train(samples_train, labels_train)
 | 
			
		||||
    # Simple (cross-validation)
 | 
			
		||||
    model = SVM(kernel_type=cv2.SVM_LINEAR, C=0.1)
 | 
			
		||||
    model.train(samples_train, labels_train)
 | 
			
		||||
 | 
			
		||||
    vis = evaluate_model(model, digits_test, samples_test, labels_test)
 | 
			
		||||
    cv2.imshow('SVM test', vis)
 | 
			
		||||
    print 'saving SVM as "digits_svm.dat"...'
 | 
			
		||||
    model.save('digits_svm.dat')
 | 
			
		||||
 | 
			
		||||
    cv2.waitKey(0)
 | 
			
		||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue