#!/usr/bin/env python

'''
MOSSE tracking sample

This sample implements correlation-based tracking approach, described in [1].

Usage:
  mosse.py [--pause] [<video source>]

  --pause  -  Start with playback paused at the first video frame.
              Useful for tracking target selection.

  Draw rectangles around objects with a mouse to track them.

Keys:
  SPACE    - pause video
  c        - clear targets

[1] David S. Bolme et al. "Visual Object Tracking using Adaptive Correlation Filters"
    http://www.cs.colostate.edu/~bolme/publications/Bolme2010Tracking.pdf
'''

import numpy as np
import cv2

from common import draw_str, RectSelector
import video

# Small constant guarding divisions against zero denominators.
eps = 1e-5


def rnd_warp(a):
    """Return a randomly warped copy of image `a` (same size).

    Applies a small random rotation plus a small random shear, keeping the
    image center fixed.  Used to synthesize training variations of the
    initial target patch when bootstrapping the MOSSE filter.
    """
    h, w = a.shape[:2]
    T = np.zeros((2, 3))
    coef = 0.2  # magnitude of the random rotation (radians) and shear
    ang = (np.random.rand() - 0.5) * coef
    c, s = np.cos(ang), np.sin(ang)
    T[:2, :2] = [[c, -s], [s, c]]
    # Add a small random shear/scale perturbation on top of the rotation.
    T[:2, :2] += (np.random.rand(2, 2) - 0.5) * coef
    # Choose the translation so the image center maps onto itself.
    c = (w / 2, h / 2)
    T[:, 2] = c - np.dot(T[:2, :2], c)
    return cv2.warpAffine(a, T, (w, h), borderMode=cv2.BORDER_REFLECT)


def divSpec(A, B):
    """Element-wise complex division A / B of two DFT spectra.

    `A` and `B` are in OpenCV's packed complex format: real part in
    channel 0, imaginary part in channel 1.  Returns the quotient in the
    same packed format.
    """
    Ar, Ai = A[..., 0], A[..., 1]
    Br, Bi = B[..., 0], B[..., 1]
    C = (Ar + 1j * Ai) / (Br + 1j * Bi)
    C = np.dstack([np.real(C), np.imag(C)]).copy()
    return C


class MOSSE:
    """Single-target MOSSE correlation tracker (Bolme et al., 2010).

    The filter H is kept as the ratio H1/H2 of two running sums so it can
    be updated online with a simple exponential moving average.
    """

    def __init__(self, frame, rect):
        """Initialize the tracker on `rect` = (x1, y1, x2, y2) in `frame`.

        `frame` is a single-channel (grayscale) image.  The window size is
        rounded up to an optimal DFT size, and the filter is bootstrapped
        from 128 randomly warped copies of the initial patch.
        """
        x1, y1, x2, y2 = rect
        # Pad the window to DFT-friendly dimensions, keeping it centered.
        w, h = map(cv2.getOptimalDFTSize, [x2 - x1, y2 - y1])
        x1, y1 = (x1 + x2 - w) // 2, (y1 + y2 - h) // 2
        self.pos = x, y = x1 + 0.5 * (w - 1), y1 + 0.5 * (h - 1)
        self.size = w, h
        img = cv2.getRectSubPix(frame, (w, h), (x, y))

        self.win = cv2.createHanningWindow((w, h), cv2.CV_32F)
        # Desired correlation output: a sharp Gaussian peak at the center.
        g = np.zeros((h, w), np.float32)
        g[h // 2, w // 2] = 1
        g = cv2.GaussianBlur(g, (-1, -1), 2.0)
        g /= g.max()

        self.G = cv2.dft(g, flags=cv2.DFT_COMPLEX_OUTPUT)
        self.H1 = np.zeros_like(self.G)
        self.H2 = np.zeros_like(self.G)
        # Bootstrap the filter from random affine warps of the target patch.
        for _i in range(128):
            a = self.preprocess(rnd_warp(img))
            A = cv2.dft(a, flags=cv2.DFT_COMPLEX_OUTPUT)
            self.H1 += cv2.mulSpectrums(self.G, A, 0, conjB=True)
            self.H2 += cv2.mulSpectrums(     A, A, 0, conjB=True)
        self.update_kernel()
        self.update(frame)

    def update(self, frame, rate=0.125):
        """Track the target in `frame` and adapt the filter.

        `rate` is the learning rate of the exponential moving average of
        the filter numerator/denominator.  Sets `self.good` to False (and
        skips adaptation) when the peak-to-sidelobe ratio drops below 8.0,
        i.e. when the target is considered lost/occluded.
        """
        (x, y), (w, h) = self.pos, self.size
        self.last_img = img = cv2.getRectSubPix(frame, (w, h), (x, y))
        img = self.preprocess(img)
        self.last_resp, (dx, dy), self.psr = self.correlate(img)
        self.good = self.psr > 8.0
        if not self.good:
            return

        # Re-sample the patch at the refined position before adapting.
        self.pos = x + dx, y + dy
        self.last_img = img = cv2.getRectSubPix(frame, (w, h), self.pos)
        img = self.preprocess(img)

        A = cv2.dft(img, flags=cv2.DFT_COMPLEX_OUTPUT)
        H1 = cv2.mulSpectrums(self.G, A, 0, conjB=True)
        H2 = cv2.mulSpectrums(     A, A, 0, conjB=True)
        self.H1 = self.H1 * (1.0 - rate) + H1 * rate
        self.H2 = self.H2 * (1.0 - rate) + H2 * rate
        self.update_kernel()

    @property
    def state_vis(self):
        """Visualization strip: [last patch | filter kernel | response map]."""
        f = cv2.idft(self.H, flags=cv2.DFT_SCALE | cv2.DFT_REAL_OUTPUT)
        h, w = f.shape
        # Shift the kernel so its center appears in the middle of the image.
        f = np.roll(f, -h // 2, 0)
        f = np.roll(f, -w // 2, 1)
        # np.ptp(f) instead of f.ptp(): ndarray.ptp was removed in NumPy 2.0.
        # + eps guards against a constant kernel (range 0 -> NaN on cast).
        kernel = np.uint8((f - f.min()) / (np.ptp(f) + eps) * 255)
        resp = self.last_resp
        # + eps guards against an all-zero response map (0/0 -> NaN).
        resp = np.uint8(np.clip(resp / (resp.max() + eps), 0, 1) * 255)
        vis = np.hstack([self.last_img, kernel, resp])
        return vis

    def draw_state(self, vis):
        """Draw the tracker state (box, center or 'lost' cross, PSR) on `vis`."""
        (x, y), (w, h) = self.pos, self.size
        x1, y1, x2, y2 = int(x - 0.5 * w), int(y - 0.5 * h), int(x + 0.5 * w), int(y + 0.5 * h)
        cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 0, 255))
        if self.good:
            cv2.circle(vis, (int(x), int(y)), 2, (0, 0, 255), -1)
        else:
            # Cross out the box when the target is considered lost.
            cv2.line(vis, (x1, y1), (x2, y2), (0, 0, 255))
            cv2.line(vis, (x2, y1), (x1, y2), (0, 0, 255))
        draw_str(vis, (x1, y2 + 16), 'PSR: %.2f' % self.psr)

    def preprocess(self, img):
        """Log-transform, normalize to zero mean/unit variance, apply Hanning window."""
        img = np.log(np.float32(img) + 1.0)
        img = (img - img.mean()) / (img.std() + eps)
        return img * self.win

    def correlate(self, img):
        """Correlate preprocessed patch `img` with the filter.

        Returns (response_map, (dx, dy), psr) where (dx, dy) is the peak
        displacement from the window center and `psr` is the
        peak-to-sidelobe ratio used as a tracking-quality measure.
        """
        C = cv2.mulSpectrums(cv2.dft(img, flags=cv2.DFT_COMPLEX_OUTPUT), self.H, 0, conjB=True)
        resp = cv2.idft(C, flags=cv2.DFT_SCALE | cv2.DFT_REAL_OUTPUT)
        h, w = resp.shape
        _, mval, _, (mx, my) = cv2.minMaxLoc(resp)
        # PSR: compare the peak against the sidelobe statistics, excluding
        # an 11x11 neighborhood around the peak itself.
        side_resp = resp.copy()
        cv2.rectangle(side_resp, (mx - 5, my - 5), (mx + 5, my + 5), 0, -1)
        smean, sstd = side_resp.mean(), side_resp.std()
        psr = (mval - smean) / (sstd + eps)
        return resp, (mx - w // 2, my - h // 2), psr

    def update_kernel(self):
        """Recompute the packed-complex filter H = conj(H1 / H2)."""
        self.H = divSpec(self.H1, self.H2)
        self.H[..., 1] *= -1  # conjugate


class App:
    """Interactive demo: select rectangles with the mouse to spawn trackers."""

    def __init__(self, video_src, paused=False):
        self.cap = video.create_capture(video_src)
        _, self.frame = self.cap.read()
        cv2.imshow('frame', self.frame)
        self.rect_sel = RectSelector('frame', self.onrect)
        self.trackers = []
        self.paused = paused

    def onrect(self, rect):
        """Mouse-selection callback: start a new tracker on `rect`."""
        frame_gray = cv2.cvtColor(self.frame, cv2.COLOR_BGR2GRAY)
        tracker = MOSSE(frame_gray, rect)
        self.trackers.append(tracker)

    def run(self):
        """Main loop: grab frames, update trackers, render, handle keys."""
        while True:
            if not self.paused:
                ret, self.frame = self.cap.read()
                if not ret:
                    break
                frame_gray = cv2.cvtColor(self.frame, cv2.COLOR_BGR2GRAY)
                for tracker in self.trackers:
                    tracker.update(frame_gray)

            vis = self.frame.copy()
            for tracker in self.trackers:
                tracker.draw_state(vis)
            if len(self.trackers) > 0:
                # Show internals of the most recently added tracker.
                cv2.imshow('tracker state', self.trackers[-1].state_vis)
            self.rect_sel.draw(vis)

            cv2.imshow('frame', vis)
            ch = cv2.waitKey(10) & 0xFF
            if ch == 27:  # ESC
                break
            if ch == ord(' '):
                self.paused = not self.paused
            if ch == ord('c'):
                self.trackers = []


if __name__ == '__main__':
    print(__doc__)  # Python 3 print function (was a Python 2 print statement)
    import sys, getopt
    opts, args = getopt.getopt(sys.argv[1:], '', ['pause'])
    opts = dict(opts)
    try:
        video_src = args[0]
    except IndexError:  # narrowed from a bare except: only "no argument given"
        video_src = '0'

    App(video_src, paused='--pause' in opts).run()