Home | History | Annotate | Download | only in python2
      1 #!/usr/bin/env python
      2 
      3 '''
      4 SVM and KNearest digit recognition.
      5 
      6 Sample loads a dataset of handwritten digits from '../data/digits.png'.
      7 Then it trains a SVM and KNearest classifiers on it and evaluates
      8 their accuracy.
      9 
     10 Following preprocessing is applied to the dataset:
     11  - Moment-based image deskew (see deskew())
     12  - Digit images are split into 4 10x10 cells and 16-bin
     13    histogram of oriented gradients is computed for each
     14    cell
     15  - Transform histograms to space with Hellinger metric (see [1] (RootSIFT))
     16 
     17 
     18 [1] R. Arandjelovic, A. Zisserman
     19     "Three things everyone should know to improve object retrieval"
     20     http://www.robots.ox.ac.uk/~vgg/publications/2012/Arandjelovic12/arandjelovic12.pdf
     21 
     22 Usage:
     23    digits.py
     24 '''
     25 
     26 # built-in modules
     27 from multiprocessing.pool import ThreadPool
     28 
     29 import cv2
     30 
     31 import numpy as np
     32 from numpy.linalg import norm
     33 
     34 # local modules
     35 from common import clock, mosaic
     36 
     37 
     38 
     39 SZ = 20 # size of each digit is SZ x SZ
     40 CLASS_N = 10
     41 DIGITS_FN = '../data/digits.png'
     42 
     43 def split2d(img, cell_size, flatten=True):
     44     h, w = img.shape[:2]
     45     sx, sy = cell_size
     46     cells = [np.hsplit(row, w//sx) for row in np.vsplit(img, h//sy)]
     47     cells = np.array(cells)
     48     if flatten:
     49         cells = cells.reshape(-1, sy, sx)
     50     return cells
     51 
     52 def load_digits(fn):
     53     print 'loading "%s" ...' % fn
     54     digits_img = cv2.imread(fn, 0)
     55     digits = split2d(digits_img, (SZ, SZ))
     56     labels = np.repeat(np.arange(CLASS_N), len(digits)/CLASS_N)
     57     return digits, labels
     58 
     59 def deskew(img):
     60     m = cv2.moments(img)
     61     if abs(m['mu02']) < 1e-2:
     62         return img.copy()
     63     skew = m['mu11']/m['mu02']
     64     M = np.float32([[1, skew, -0.5*SZ*skew], [0, 1, 0]])
     65     img = cv2.warpAffine(img, M, (SZ, SZ), flags=cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR)
     66     return img
     67 
     68 class StatModel(object):
     69     def load(self, fn):
     70         self.model.load(fn)
     71     def save(self, fn):
     72         self.model.save(fn)
     73 
     74 class KNearest(StatModel):
     75     def __init__(self, k = 3):
     76         self.k = k
     77         self.model = cv2.ml.KNearest_create()
     78 
     79     def train(self, samples, responses):
     80         self.model = cv2.ml.KNearest_create()
     81         self.model.train(samples, cv2.ml.ROW_SAMPLE, responses)
     82 
     83     def predict(self, samples):
     84         retval, results, neigh_resp, dists = self.model.findNearest(samples, self.k)
     85         return results.ravel()
     86 
     87 class SVM(StatModel):
     88     def __init__(self, C = 1, gamma = 0.5):
     89         self.model = cv2.ml.SVM_create()
     90         self.model.setGamma(gamma)
     91         self.model.setC(C)
     92         self.model.setKernel(cv2.ml.SVM_RBF)
     93         self.model.setType(cv2.ml.SVM_C_SVC)
     94 
     95     def train(self, samples, responses):
     96         self.model = cv2.ml.SVM_create()
     97         self.model.train(samples, cv2.ml.ROW_SAMPLE, responses)
     98 
     99     def predict(self, samples):
    100         return self.model.predict(samples)[1][0].ravel()
    101 
    102 
    103 def evaluate_model(model, digits, samples, labels):
    104     resp = model.predict(samples)
    105     err = (labels != resp).mean()
    106     print 'error: %.2f %%' % (err*100)
    107 
    108     confusion = np.zeros((10, 10), np.int32)
    109     for i, j in zip(labels, resp):
    110         confusion[i, j] += 1
    111     print 'confusion matrix:'
    112     print confusion
    113     print
    114 
    115     vis = []
    116     for img, flag in zip(digits, resp == labels):
    117         img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    118         if not flag:
    119             img[...,:2] = 0
    120         vis.append(img)
    121     return mosaic(25, vis)
    122 
    123 def preprocess_simple(digits):
    124     return np.float32(digits).reshape(-1, SZ*SZ) / 255.0
    125 
    126 def preprocess_hog(digits):
    127     samples = []
    128     for img in digits:
    129         gx = cv2.Sobel(img, cv2.CV_32F, 1, 0)
    130         gy = cv2.Sobel(img, cv2.CV_32F, 0, 1)
    131         mag, ang = cv2.cartToPolar(gx, gy)
    132         bin_n = 16
    133         bin = np.int32(bin_n*ang/(2*np.pi))
    134         bin_cells = bin[:10,:10], bin[10:,:10], bin[:10,10:], bin[10:,10:]
    135         mag_cells = mag[:10,:10], mag[10:,:10], mag[:10,10:], mag[10:,10:]
    136         hists = [np.bincount(b.ravel(), m.ravel(), bin_n) for b, m in zip(bin_cells, mag_cells)]
    137         hist = np.hstack(hists)
    138 
    139         # transform to Hellinger kernel
    140         eps = 1e-7
    141         hist /= hist.sum() + eps
    142         hist = np.sqrt(hist)
    143         hist /= norm(hist) + eps
    144 
    145         samples.append(hist)
    146     return np.float32(samples)
    147 
    148 
    149 if __name__ == '__main__':
    150     print __doc__
    151 
    152     digits, labels = load_digits(DIGITS_FN)
    153 
    154     print 'preprocessing...'
    155     # shuffle digits
    156     rand = np.random.RandomState(321)
    157     shuffle = rand.permutation(len(digits))
    158     digits, labels = digits[shuffle], labels[shuffle]
    159 
    160     digits2 = map(deskew, digits)
    161     samples = preprocess_hog(digits2)
    162 
    163     train_n = int(0.9*len(samples))
    164     cv2.imshow('test set', mosaic(25, digits[train_n:]))
    165     digits_train, digits_test = np.split(digits2, [train_n])
    166     samples_train, samples_test = np.split(samples, [train_n])
    167     labels_train, labels_test = np.split(labels, [train_n])
    168 
    169 
    170     print 'training KNearest...'
    171     model = KNearest(k=4)
    172     model.train(samples_train, labels_train)
    173     vis = evaluate_model(model, digits_test, samples_test, labels_test)
    174     cv2.imshow('KNearest test', vis)
    175 
    176     print 'training SVM...'
    177     model = SVM(C=2.67, gamma=5.383)
    178     model.train(samples_train, labels_train)
    179     vis = evaluate_model(model, digits_test, samples_test, labels_test)
    180     cv2.imshow('SVM test', vis)
    181     print 'saving SVM as "digits_svm.dat"...'
    182     model.save('digits_svm.dat')
    183 
    184     cv2.waitKey(0)
    185