Home | History | Annotate | Download | only in object_tracking
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
     17 #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
     18 
     19 #include "tensorflow/examples/android/jni/object_tracking/geom.h"
     20 #include "tensorflow/examples/android/jni/object_tracking/utils.h"
     21 
     22 #include "tensorflow/examples/android/jni/object_tracking/config.h"
     23 #include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
     24 
     25 namespace tf_tracking {
     26 
     27 // Class that helps OpticalFlow to speed up flow computation
     28 // by caching coarse-grained flow.
     29 class FlowCache {
     30  public:
     31   explicit FlowCache(const OpticalFlowConfig* const config)
     32       : config_(config),
     33         image_size_(config->image_size),
     34         optical_flow_(config),
     35         fullframe_matrix_(NULL) {
     36     for (int i = 0; i < kNumCacheLevels; ++i) {
     37       const int curr_dims = BlockDimForCacheLevel(i);
     38       has_cache_[i] = new Image<bool>(curr_dims, curr_dims);
     39       displacements_[i] = new Image<Point2f>(curr_dims, curr_dims);
     40     }
     41   }
     42 
     43   ~FlowCache() {
     44     for (int i = 0; i < kNumCacheLevels; ++i) {
     45       SAFE_DELETE(has_cache_[i]);
     46       SAFE_DELETE(displacements_[i]);
     47     }
     48     delete[](fullframe_matrix_);
     49     fullframe_matrix_ = NULL;
     50   }
     51 
     52   void NextFrame(ImageData* const new_frame,
     53                  const float* const align_matrix23) {
     54     ClearCache();
     55     SetFullframeAlignmentMatrix(align_matrix23);
     56     optical_flow_.NextFrame(new_frame);
     57   }
     58 
     59   void ClearCache() {
     60     for (int i = 0; i < kNumCacheLevels; ++i) {
     61       has_cache_[i]->Clear(false);
     62     }
     63     delete[](fullframe_matrix_);
     64     fullframe_matrix_ = NULL;
     65   }
     66 
     67   // Finds the flow at a point, using the cache for performance.
     68   bool FindFlowAtPoint(const float u_x, const float u_y,
     69                        float* const flow_x, float* const flow_y) const {
     70     // Get the best guess from the cache.
     71     const Point2f guess_from_cache = LookupGuess(u_x, u_y);
     72 
     73     *flow_x = guess_from_cache.x;
     74     *flow_y = guess_from_cache.y;
     75 
     76     // Now refine the guess using the image pyramid.
     77     for (int pyramid_level = kMinNumPyramidLevelsToUseForAdjustment - 1;
     78         pyramid_level >= 0; --pyramid_level) {
     79       if (!optical_flow_.FindFlowAtPointSingleLevel(
     80           pyramid_level, u_x, u_y, false, flow_x, flow_y)) {
     81         return false;
     82       }
     83     }
     84 
     85     return true;
     86   }
     87 
     88   // Determines the displacement of a point, and uses that to calculate a new
     89   // position.
     90   // Returns true iff the displacement determination worked and the new position
     91   // is in the image.
     92   bool FindNewPositionOfPoint(const float u_x, const float u_y,
     93                               float* final_x, float* final_y) const {
     94     float flow_x;
     95     float flow_y;
     96     if (!FindFlowAtPoint(u_x, u_y, &flow_x, &flow_y)) {
     97       return false;
     98     }
     99 
    100     // Add in the displacement to get the final position.
    101     *final_x = u_x + flow_x;
    102     *final_y = u_y + flow_y;
    103 
    104     // Assign the best guess, if we're still in the image.
    105     if (InRange(*final_x, 0.0f, static_cast<float>(image_size_.width) - 1) &&
    106         InRange(*final_y, 0.0f, static_cast<float>(image_size_.height) - 1)) {
    107       return true;
    108     } else {
    109       return false;
    110     }
    111   }
    112 
    113   // Comparison function for qsort.
    114   static int Compare(const void* a, const void* b) {
    115     return *reinterpret_cast<const float*>(a) -
    116            *reinterpret_cast<const float*>(b);
    117   }
    118 
    119   // Returns the median flow within the given bounding box as determined
    120   // by a grid_width x grid_height grid.
    121   Point2f GetMedianFlow(const BoundingBox& bounding_box,
    122                         const bool filter_by_fb_error,
    123                         const int grid_width,
    124                         const int grid_height) const {
    125     const int kMaxPoints = 100;
    126     SCHECK(grid_width * grid_height <= kMaxPoints,
    127           "Too many points for Median flow!");
    128 
    129     const BoundingBox valid_box = bounding_box.Intersect(
    130         BoundingBox(0, 0, image_size_.width - 1, image_size_.height - 1));
    131 
    132     if (valid_box.GetArea() <= 0.0f) {
    133       return Point2f(0, 0);
    134     }
    135 
    136     float x_deltas[kMaxPoints];
    137     float y_deltas[kMaxPoints];
    138 
    139     int curr_offset = 0;
    140     for (int i = 0; i < grid_width; ++i) {
    141       for (int j = 0; j < grid_height; ++j) {
    142         const float x_in = valid_box.left_ +
    143             (valid_box.GetWidth() * i) / (grid_width - 1);
    144 
    145         const float y_in = valid_box.top_ +
    146             (valid_box.GetHeight() * j) / (grid_height - 1);
    147 
    148         float curr_flow_x;
    149         float curr_flow_y;
    150         const bool success = FindNewPositionOfPoint(x_in, y_in,
    151                                                     &curr_flow_x, &curr_flow_y);
    152 
    153         if (success) {
    154           x_deltas[curr_offset] = curr_flow_x;
    155           y_deltas[curr_offset] = curr_flow_y;
    156           ++curr_offset;
    157         } else {
    158           LOGW("Tracking failure!");
    159         }
    160       }
    161     }
    162 
    163     if (curr_offset > 0) {
    164       qsort(x_deltas, curr_offset, sizeof(*x_deltas), Compare);
    165       qsort(y_deltas, curr_offset, sizeof(*y_deltas), Compare);
    166 
    167       return Point2f(x_deltas[curr_offset / 2], y_deltas[curr_offset / 2]);
    168     }
    169 
    170     LOGW("No points were valid!");
    171     return Point2f(0, 0);
    172   }
    173 
    174   void SetFullframeAlignmentMatrix(const float* const align_matrix23) {
    175     if (align_matrix23 != NULL) {
    176       if (fullframe_matrix_ == NULL) {
    177         fullframe_matrix_ = new float[6];
    178       }
    179 
    180       memcpy(fullframe_matrix_, align_matrix23,
    181              6 * sizeof(fullframe_matrix_[0]));
    182     }
    183   }
    184 
    185  private:
    186   Point2f LookupGuessFromLevel(
    187       const int cache_level, const float x, const float y) const {
    188     // LOGE("Looking up guess at %5.2f %5.2f for level %d.", x, y, cache_level);
    189 
    190     // Cutoff at the target level and use the matrix transform instead.
    191     if (fullframe_matrix_ != NULL && cache_level == kCacheCutoff) {
    192       const float xnew = x * fullframe_matrix_[0] +
    193                          y * fullframe_matrix_[1] +
    194                              fullframe_matrix_[2];
    195       const float ynew = x * fullframe_matrix_[3] +
    196                          y * fullframe_matrix_[4] +
    197                              fullframe_matrix_[5];
    198 
    199       return Point2f(xnew - x, ynew - y);
    200     }
    201 
    202     const int level_dim = BlockDimForCacheLevel(cache_level);
    203     const int pixels_per_cache_block_x =
    204         (image_size_.width + level_dim - 1) / level_dim;
    205     const int pixels_per_cache_block_y =
    206         (image_size_.height + level_dim - 1) / level_dim;
    207     const int index_x = x / pixels_per_cache_block_x;
    208     const int index_y = y / pixels_per_cache_block_y;
    209 
    210     Point2f displacement;
    211     if (!(*has_cache_[cache_level])[index_y][index_x]) {
    212       (*has_cache_[cache_level])[index_y][index_x] = true;
    213 
    214       // Get the lower cache level's best guess, if it exists.
    215       displacement = cache_level >= kNumCacheLevels - 1 ?
    216           Point2f(0, 0) : LookupGuessFromLevel(cache_level + 1, x, y);
    217       // LOGI("Best guess at cache level %d is %5.2f, %5.2f.", cache_level,
    218       //      best_guess.x, best_guess.y);
    219 
    220       // Find the center of the block.
    221       const float center_x = (index_x + 0.5f) * pixels_per_cache_block_x;
    222       const float center_y = (index_y + 0.5f) * pixels_per_cache_block_y;
    223       const int pyramid_level = PyramidLevelForCacheLevel(cache_level);
    224 
    225       // LOGI("cache level %d: [%d, %d (%5.2f / %d, %5.2f / %d)] "
    226       //      "Querying %5.2f, %5.2f at pyramid level %d, ",
    227       //      cache_level, index_x, index_y,
    228       //      x, pixels_per_cache_block_x, y, pixels_per_cache_block_y,
    229       //      center_x, center_y, pyramid_level);
    230 
    231       // TODO(andrewharp): Turn on FB error filtering.
    232       const bool success = optical_flow_.FindFlowAtPointSingleLevel(
    233           pyramid_level, center_x, center_y, false,
    234           &displacement.x, &displacement.y);
    235 
    236       if (!success) {
    237         LOGV("Computation of cached value failed for level %d!", cache_level);
    238       }
    239 
    240       // Store the value for later use.
    241       (*displacements_[cache_level])[index_y][index_x] = displacement;
    242     } else {
    243       displacement = (*displacements_[cache_level])[index_y][index_x];
    244     }
    245 
    246     // LOGI("Returning %5.2f, %5.2f for level %d",
    247     //      displacement.x, displacement.y, cache_level);
    248     return displacement;
    249   }
    250 
    251   Point2f LookupGuess(const float x, const float y) const {
    252     if (x < 0 || x >= image_size_.width || y < 0 || y >= image_size_.height) {
    253       return Point2f(0, 0);
    254     }
    255 
    256     // LOGI("Looking up guess at %5.2f %5.2f.", x, y);
    257     if (kNumCacheLevels > 0) {
    258       return LookupGuessFromLevel(0, x, y);
    259     } else {
    260       return Point2f(0, 0);
    261     }
    262   }
    263 
    264   // Returns the number of cache bins in each dimension for a given level
    265   // of the cache.
    266   int BlockDimForCacheLevel(const int cache_level) const {
    267     // The highest (coarsest) cache level has a block dim of kCacheBranchFactor,
    268     // thus if there are 4 cache levels, requesting level 3 (0-based) should
    269     // return kCacheBranchFactor, level 2 should return kCacheBranchFactor^2,
    270     // and so on.
    271     int block_dim = kNumCacheLevels;
    272     for (int curr_level = kNumCacheLevels - 1; curr_level > cache_level;
    273         --curr_level) {
    274       block_dim *= kCacheBranchFactor;
    275     }
    276     return block_dim;
    277   }
    278 
    279   // Returns the level of the image pyramid that a given cache level maps to.
    280   int PyramidLevelForCacheLevel(const int cache_level) const {
    281     // Higher cache and pyramid levels have smaller dimensions. The highest
    282     // cache level should refer to the highest image pyramid level. The
    283     // lower, finer image pyramid levels are uncached (assuming
    284     // kNumCacheLevels < kNumPyramidLevels).
    285     return cache_level + (kNumPyramidLevels - kNumCacheLevels);
    286   }
    287 
    288   const OpticalFlowConfig* const config_;
    289 
    290   const Size image_size_;
    291   OpticalFlow optical_flow_;
    292 
    293   float* fullframe_matrix_;
    294 
    295   // Whether this value is currently present in the cache.
    296   Image<bool>* has_cache_[kNumCacheLevels];
    297 
    298   // The cached displacement values.
    299   Image<Point2f>* displacements_[kNumCacheLevels];
    300 
    301   TF_DISALLOW_COPY_AND_ASSIGN(FlowCache);
    302 };
    303 
    304 }  // namespace tf_tracking
    305 
    306 #endif  // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
    307