Home | History | Annotate | Download | only in python2
      1 #!/usr/bin/env python
      2 
      3 '''
      4 MOSSE tracking sample
      5 
      6 This sample implements correlation-based tracking approach, described in [1].
      7 
      8 Usage:
      9   mosse.py [--pause] [<video source>]
     10 
     11   --pause  -  Start with playback paused at the first video frame.
     12               Useful for tracking target selection.
     13 
     14   Draw rectangles around objects with a mouse to track them.
     15 
     16 Keys:
     17   SPACE    - pause video
     18   c        - clear targets
     19 
     20 [1] David S. Bolme et al. "Visual Object Tracking using Adaptive Correlation Filters"
     21     http://www.cs.colostate.edu/~bolme/publications/Bolme2010Tracking.pdf
     22 '''
     23 
     24 import numpy as np
     25 import cv2
     26 from common import draw_str, RectSelector
     27 import video
     28 
     29 def rnd_warp(a):
     30     h, w = a.shape[:2]
     31     T = np.zeros((2, 3))
     32     coef = 0.2
     33     ang = (np.random.rand()-0.5)*coef
     34     c, s = np.cos(ang), np.sin(ang)
     35     T[:2, :2] = [[c,-s], [s, c]]
     36     T[:2, :2] += (np.random.rand(2, 2) - 0.5)*coef
     37     c = (w/2, h/2)
     38     T[:,2] = c - np.dot(T[:2, :2], c)
     39     return cv2.warpAffine(a, T, (w, h), borderMode = cv2.BORDER_REFLECT)
     40 
     41 def divSpec(A, B):
     42     Ar, Ai = A[...,0], A[...,1]
     43     Br, Bi = B[...,0], B[...,1]
     44     C = (Ar+1j*Ai)/(Br+1j*Bi)
     45     C = np.dstack([np.real(C), np.imag(C)]).copy()
     46     return C
     47 
     48 eps = 1e-5
     49 
class MOSSE:
    """Single-target MOSSE correlation-filter tracker (Bolme et al. 2010).

    The filter H is maintained in the Fourier domain as the ratio of two
    running sums, H = H1 / H2, which are updated online frame by frame.
    Tracking quality is judged by the Peak-to-Sidelobe Ratio (PSR) of the
    correlation response.
    """

    def __init__(self, frame, rect):
        # Snap the user selection to an FFT-friendly size, centered on the
        # original rectangle.
        x1, y1, x2, y2 = rect
        w, h = map(cv2.getOptimalDFTSize, [x2-x1, y2-y1])
        x1, y1 = (x1+x2-w)//2, (y1+y2-h)//2
        self.pos = x, y = x1+0.5*(w-1), y1+0.5*(h-1)
        self.size = w, h
        img = cv2.getRectSubPix(frame, (w, h), (x, y))

        # Hanning window suppresses FFT edge artifacts; g is the desired
        # response: a narrow Gaussian peak at the patch center.
        self.win = cv2.createHanningWindow((w, h), cv2.CV_32F)
        g = np.zeros((h, w), np.float32)
        g[h//2, w//2] = 1
        g = cv2.GaussianBlur(g, (-1, -1), 2.0)  # ksize (-1,-1): derived from sigma
        g /= g.max()

        self.G = cv2.dft(g, flags=cv2.DFT_COMPLEX_OUTPUT)
        self.H1 = np.zeros_like(self.G)
        self.H2 = np.zeros_like(self.G)
        # Bootstrap the filter from 128 random affine perturbations of the
        # initial patch (numerator H1, denominator H2 of the MOSSE update).
        for i in xrange(128):
            a = self.preprocess(rnd_warp(img))
            A = cv2.dft(a, flags=cv2.DFT_COMPLEX_OUTPUT)
            self.H1 += cv2.mulSpectrums(self.G, A, 0, conjB=True)
            self.H2 += cv2.mulSpectrums(     A, A, 0, conjB=True)
        self.update_kernel()
        self.update(frame)

    def update(self, frame, rate = 0.125):
        """Track one frame: find the correlation peak, recenter, then blend
        new filter terms in with learning rate *rate*.

        Sets self.good (tracking confidence) and self.psr; skips both the
        position update and the filter update when the PSR is too low.
        """
        (x, y), (w, h) = self.pos, self.size
        self.last_img = img = cv2.getRectSubPix(frame, (w, h), (x, y))
        img = self.preprocess(img)
        self.last_resp, (dx, dy), self.psr = self.correlate(img)
        # PSR below ~8 indicates an unreliable peak (occlusion / target lost);
        # freeze the filter so it does not learn the occluder.
        self.good = self.psr > 8.0
        if not self.good:
            return

        self.pos = x+dx, y+dy
        # Re-sample the patch at the corrected position before learning.
        self.last_img = img = cv2.getRectSubPix(frame, (w, h), self.pos)
        img = self.preprocess(img)

        A = cv2.dft(img, flags=cv2.DFT_COMPLEX_OUTPUT)
        H1 = cv2.mulSpectrums(self.G, A, 0, conjB=True)
        H2 = cv2.mulSpectrums(     A, A, 0, conjB=True)
        # Exponential moving average of numerator and denominator.
        self.H1 = self.H1 * (1.0-rate) + H1 * rate
        self.H2 = self.H2 * (1.0-rate) + H2 * rate
        self.update_kernel()

    @property
    def state_vis(self):
        """Debug visualization: tracked patch | filter kernel | response map,
        stacked horizontally as one uint8 image."""
        f = cv2.idft(self.H, flags=cv2.DFT_SCALE | cv2.DFT_REAL_OUTPUT )
        h, w = f.shape
        # FFT puts the kernel origin at (0,0); roll it to the image center.
        f = np.roll(f, -h//2, 0)
        f = np.roll(f, -w//2, 1)
        kernel = np.uint8( (f-f.min()) / f.ptp()*255 )
        resp = self.last_resp
        resp = np.uint8(np.clip(resp/resp.max(), 0, 1)*255)
        vis = np.hstack([self.last_img, kernel, resp])
        return vis

    def draw_state(self, vis):
        """Draw the tracking rectangle, a center dot (or an X when tracking is
        lost) and the PSR value onto *vis* (BGR image, modified in place)."""
        (x, y), (w, h) = self.pos, self.size
        x1, y1, x2, y2 = int(x-0.5*w), int(y-0.5*h), int(x+0.5*w), int(y+0.5*h)
        cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 0, 255))
        if self.good:
            cv2.circle(vis, (int(x), int(y)), 2, (0, 0, 255), -1)
        else:
            # Target lost: cross out the search window.
            cv2.line(vis, (x1, y1), (x2, y2), (0, 0, 255))
            cv2.line(vis, (x2, y1), (x1, y2), (0, 0, 255))
        draw_str(vis, (x1, y2+16), 'PSR: %.2f' % self.psr)

    def preprocess(self, img):
        """Log-transform, normalize to zero mean / unit variance, and apply
        the Hanning window (standard MOSSE preprocessing)."""
        img = np.log(np.float32(img)+1.0)
        img = (img-img.mean()) / (img.std()+eps)
        return img*self.win

    def correlate(self, img):
        """Correlate preprocessed patch *img* with the filter.

        Returns (response image, (dx, dy) peak offset from patch center, PSR).
        """
        C = cv2.mulSpectrums(cv2.dft(img, flags=cv2.DFT_COMPLEX_OUTPUT), self.H, 0, conjB=True)
        resp = cv2.idft(C, flags=cv2.DFT_SCALE | cv2.DFT_REAL_OUTPUT)
        h, w = resp.shape
        _, mval, _, (mx, my) = cv2.minMaxLoc(resp)
        # PSR = (peak - sidelobe mean) / sidelobe std, where the sidelobe is
        # everything outside an 11x11 window around the peak.
        side_resp = resp.copy()
        cv2.rectangle(side_resp, (mx-5, my-5), (mx+5, my+5), 0, -1)
        smean, sstd = side_resp.mean(), side_resp.std()
        psr = (mval-smean) / (sstd+eps)
        return resp, (mx-w//2, my-h//2), psr

    def update_kernel(self):
        # H = H1 / H2; negating the imaginary part conjugates H so that
        # correlate() performs correlation rather than convolution.
        self.H = divSpec(self.H1, self.H2)
        self.H[...,1] *= -1
    138 
    139 class App:
    140     def __init__(self, video_src, paused = False):
    141         self.cap = video.create_capture(video_src)
    142         _, self.frame = self.cap.read()
    143         cv2.imshow('frame', self.frame)
    144         self.rect_sel = RectSelector('frame', self.onrect)
    145         self.trackers = []
    146         self.paused = paused
    147 
    148     def onrect(self, rect):
    149         frame_gray = cv2.cvtColor(self.frame, cv2.COLOR_BGR2GRAY)
    150         tracker = MOSSE(frame_gray, rect)
    151         self.trackers.append(tracker)
    152 
    153     def run(self):
    154         while True:
    155             if not self.paused:
    156                 ret, self.frame = self.cap.read()
    157                 if not ret:
    158                     break
    159                 frame_gray = cv2.cvtColor(self.frame, cv2.COLOR_BGR2GRAY)
    160                 for tracker in self.trackers:
    161                     tracker.update(frame_gray)
    162 
    163             vis = self.frame.copy()
    164             for tracker in self.trackers:
    165                 tracker.draw_state(vis)
    166             if len(self.trackers) > 0:
    167                 cv2.imshow('tracker state', self.trackers[-1].state_vis)
    168             self.rect_sel.draw(vis)
    169 
    170             cv2.imshow('frame', vis)
    171             ch = cv2.waitKey(10) & 0xFF
    172             if ch == 27:
    173                 break
    174             if ch == ord(' '):
    175                 self.paused = not self.paused
    176             if ch == ord('c'):
    177                 self.trackers = []
    178 
    179 
    180 if __name__ == '__main__':
    181     print __doc__
    182     import sys, getopt
    183     opts, args = getopt.getopt(sys.argv[1:], '', ['pause'])
    184     opts = dict(opts)
    185     try:
    186         video_src = args[0]
    187     except:
    188         video_src = '0'
    189 
    190     App(video_src, paused = '--pause' in opts).run()
    191