Home | History | Annotate | Download | only in object_tracking
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <float.h>
     17 
     18 #include "tensorflow/examples/android/jni/object_tracking/config.h"
     19 #include "tensorflow/examples/android/jni/object_tracking/frame_pair.h"
     20 
     21 namespace tf_tracking {
     22 
     23 void FramePair::Init(const int64_t start_time, const int64_t end_time) {
     24   start_time_ = start_time;
     25   end_time_ = end_time;
     26   memset(optical_flow_found_keypoint_, false,
     27          sizeof(*optical_flow_found_keypoint_) * kMaxKeypoints);
     28   number_of_keypoints_ = 0;
     29 }
     30 
     31 void FramePair::AdjustBox(const BoundingBox box,
     32                           float* const translation_x,
     33                           float* const translation_y,
     34                           float* const scale_x,
     35                           float* const scale_y) const {
     36   static float weights[kMaxKeypoints];
     37   static Point2f deltas[kMaxKeypoints];
     38   memset(weights, 0.0f, sizeof(*weights) * kMaxKeypoints);
     39 
     40   BoundingBox resized_box(box);
     41   resized_box.Scale(0.4f, 0.4f);
     42   FillWeights(resized_box, weights);
     43   FillTranslations(deltas);
     44 
     45   const Point2f translation = GetWeightedMedian(weights, deltas);
     46 
     47   *translation_x = translation.x;
     48   *translation_y = translation.y;
     49 
     50   const Point2f old_center = box.GetCenter();
     51   const int good_scale_points =
     52       FillScales(old_center, translation, weights, deltas);
     53 
     54   // Default scale factor is 1 for x and y.
     55   *scale_x = 1.0f;
     56   *scale_y = 1.0f;
     57 
     58   // The assumption is that all deltas that make it to this stage with a
     59   // correspondending optical_flow_found_keypoint_[i] == true are not in
     60   // themselves degenerate.
     61   //
     62   // The degeneracy with scale arose because if the points are too close to the
     63   // center of the objects, the scale ratio determination might be incalculable.
     64   //
     65   // The check for kMinNumInRange is not a degeneracy check, but merely an
     66   // attempt to ensure some sort of stability. The actual degeneracy check is in
     67   // the comparison to EPSILON in FillScales (which I've updated to return the
     68   // number good remaining as well).
     69   static const int kMinNumInRange = 5;
     70   if (good_scale_points >= kMinNumInRange) {
     71     const float scale_factor = GetWeightedMedianScale(weights, deltas);
     72 
     73     if (scale_factor > 0.0f) {
     74       *scale_x = scale_factor;
     75       *scale_y = scale_factor;
     76     }
     77   }
     78 }
     79 
     80 int FramePair::FillWeights(const BoundingBox& box,
     81                            float* const weights) const {
     82   // Compute the max score.
     83   float max_score = -FLT_MAX;
     84   float min_score = FLT_MAX;
     85   for (int i = 0; i < kMaxKeypoints; ++i) {
     86     if (optical_flow_found_keypoint_[i]) {
     87       max_score = MAX(max_score, frame1_keypoints_[i].score_);
     88       min_score = MIN(min_score, frame1_keypoints_[i].score_);
     89     }
     90   }
     91 
     92   int num_in_range = 0;
     93   for (int i = 0; i < kMaxKeypoints; ++i) {
     94     if (!optical_flow_found_keypoint_[i]) {
     95       weights[i] = 0.0f;
     96       continue;
     97     }
     98 
     99     const bool in_box = box.Contains(frame1_keypoints_[i].pos_);
    100     if (in_box) {
    101       ++num_in_range;
    102     }
    103 
    104     // The weighting based off distance.  Anything within the bounding box
    105     // has a weight of 1, and everything outside of that is within the range
    106     // [0, kOutOfBoxMultiplier), falling off with the squared distance ratio.
    107     float distance_score = 1.0f;
    108     if (!in_box) {
    109       const Point2f initial = box.GetCenter();
    110       const float sq_x_dist =
    111           Square(initial.x - frame1_keypoints_[i].pos_.x);
    112       const float sq_y_dist =
    113           Square(initial.y - frame1_keypoints_[i].pos_.y);
    114       const float squared_half_width = Square(box.GetWidth() / 2.0f);
    115       const float squared_half_height = Square(box.GetHeight() / 2.0f);
    116 
    117       static const float kOutOfBoxMultiplier = 0.5f;
    118       distance_score = kOutOfBoxMultiplier *
    119           MIN(squared_half_height / sq_y_dist, squared_half_width / sq_x_dist);
    120     }
    121 
    122     // The weighting based on relative score strength. kBaseScore - 1.0f.
    123     float intrinsic_score =  1.0f;
    124     if (max_score > min_score) {
    125       static const float kBaseScore = 0.5f;
    126       intrinsic_score = ((frame1_keypoints_[i].score_ - min_score) /
    127          (max_score - min_score)) * (1.0f - kBaseScore) + kBaseScore;
    128     }
    129 
    130     // The final score will be in the range [0, 1].
    131     weights[i] = distance_score * intrinsic_score;
    132   }
    133 
    134   return num_in_range;
    135 }
    136 
    137 void FramePair::FillTranslations(Point2f* const translations) const {
    138   for (int i = 0; i < kMaxKeypoints; ++i) {
    139     if (!optical_flow_found_keypoint_[i]) {
    140       continue;
    141     }
    142     translations[i].x =
    143         frame2_keypoints_[i].pos_.x - frame1_keypoints_[i].pos_.x;
    144     translations[i].y =
    145         frame2_keypoints_[i].pos_.y - frame1_keypoints_[i].pos_.y;
    146   }
    147 }
    148 
    149 int FramePair::FillScales(const Point2f& old_center,
    150                           const Point2f& translation,
    151                           float* const weights,
    152                           Point2f* const scales) const {
    153   int num_good = 0;
    154   for (int i = 0; i < kMaxKeypoints; ++i) {
    155     if (!optical_flow_found_keypoint_[i]) {
    156       continue;
    157     }
    158 
    159     const Keypoint keypoint1 = frame1_keypoints_[i];
    160     const Keypoint keypoint2 = frame2_keypoints_[i];
    161 
    162     const float dist1_x = keypoint1.pos_.x - old_center.x;
    163     const float dist1_y = keypoint1.pos_.y - old_center.y;
    164 
    165     const float dist2_x = (keypoint2.pos_.x - translation.x) - old_center.x;
    166     const float dist2_y = (keypoint2.pos_.y - translation.y) - old_center.y;
    167 
    168     // Make sure that the scale makes sense; points too close to the center
    169     // will result in either NaNs or infinite results for scale due to
    170     // limited tracking and floating point resolution.
    171     // Also check that the parity of the points is the same with respect to
    172     // x and y, as we can't really make sense of data that has flipped.
    173     if (((dist2_x > EPSILON && dist1_x > EPSILON) ||
    174          (dist2_x < -EPSILON && dist1_x < -EPSILON)) &&
    175          ((dist2_y > EPSILON && dist1_y > EPSILON) ||
    176           (dist2_y < -EPSILON && dist1_y < -EPSILON))) {
    177       scales[i].x = dist2_x / dist1_x;
    178       scales[i].y = dist2_y / dist1_y;
    179       ++num_good;
    180     } else {
    181       weights[i] = 0.0f;
    182       scales[i].x = 1.0f;
    183       scales[i].y = 1.0f;
    184     }
    185   }
    186   return num_good;
    187 }
    188 
    189 struct WeightedDelta {
    190   float weight;
    191   float delta;
    192 };
    193 
    194 // Sort by delta, not by weight.
    195 inline int WeightedDeltaCompare(const void* const a, const void* const b) {
    196   return (reinterpret_cast<const WeightedDelta*>(a)->delta -
    197           reinterpret_cast<const WeightedDelta*>(b)->delta) <= 0 ? 1 : -1;
    198 }
    199 
    200 // Returns the median delta from a sorted set of weighted deltas.
    201 static float GetMedian(const int num_items,
    202                        const WeightedDelta* const weighted_deltas,
    203                        const float sum) {
    204   if (num_items == 0 || sum < EPSILON) {
    205     return 0.0f;
    206   }
    207 
    208   float current_weight = 0.0f;
    209   const float target_weight = sum / 2.0f;
    210   for (int i = 0; i < num_items; ++i) {
    211     if (weighted_deltas[i].weight > 0.0f) {
    212       current_weight += weighted_deltas[i].weight;
    213       if (current_weight >= target_weight) {
    214         return weighted_deltas[i].delta;
    215       }
    216     }
    217   }
    218   LOGW("Median not found! %d points, sum of %.2f", num_items, sum);
    219   return 0.0f;
    220 }
    221 
    222 Point2f FramePair::GetWeightedMedian(
    223     const float* const weights, const Point2f* const deltas) const {
    224   Point2f median_delta;
    225 
    226   // TODO(andrewharp): only sort deltas that could possibly have an effect.
    227   static WeightedDelta weighted_deltas[kMaxKeypoints];
    228 
    229   // Compute median X value.
    230   {
    231     float total_weight = 0.0f;
    232 
    233     // Compute weighted mean and deltas.
    234     for (int i = 0; i < kMaxKeypoints; ++i) {
    235       weighted_deltas[i].delta = deltas[i].x;
    236       const float weight = weights[i];
    237       weighted_deltas[i].weight = weight;
    238       if (weight > 0.0f) {
    239         total_weight += weight;
    240       }
    241     }
    242     qsort(weighted_deltas, kMaxKeypoints, sizeof(WeightedDelta),
    243           WeightedDeltaCompare);
    244     median_delta.x = GetMedian(kMaxKeypoints, weighted_deltas, total_weight);
    245   }
    246 
    247   // Compute median Y value.
    248   {
    249     float total_weight = 0.0f;
    250 
    251     // Compute weighted mean and deltas.
    252     for (int i = 0; i < kMaxKeypoints; ++i) {
    253       const float weight = weights[i];
    254       weighted_deltas[i].weight = weight;
    255       weighted_deltas[i].delta = deltas[i].y;
    256       if (weight > 0.0f) {
    257         total_weight += weight;
    258       }
    259     }
    260     qsort(weighted_deltas, kMaxKeypoints, sizeof(WeightedDelta),
    261           WeightedDeltaCompare);
    262     median_delta.y = GetMedian(kMaxKeypoints, weighted_deltas, total_weight);
    263   }
    264 
    265   return median_delta;
    266 }
    267 
    268 float FramePair::GetWeightedMedianScale(
    269     const float* const weights, const Point2f* const deltas) const {
    270   float median_delta;
    271 
    272   // TODO(andrewharp): only sort deltas that could possibly have an effect.
    273   static WeightedDelta weighted_deltas[kMaxKeypoints * 2];
    274 
    275   // Compute median scale value across x and y.
    276   {
    277     float total_weight = 0.0f;
    278 
    279     // Add X values.
    280     for (int i = 0; i < kMaxKeypoints; ++i) {
    281       weighted_deltas[i].delta = deltas[i].x;
    282       const float weight = weights[i];
    283       weighted_deltas[i].weight = weight;
    284       if (weight > 0.0f) {
    285         total_weight += weight;
    286       }
    287     }
    288 
    289     // Add Y values.
    290     for (int i = 0; i < kMaxKeypoints; ++i) {
    291       weighted_deltas[i + kMaxKeypoints].delta = deltas[i].y;
    292       const float weight = weights[i];
    293       weighted_deltas[i + kMaxKeypoints].weight = weight;
    294       if (weight > 0.0f) {
    295         total_weight += weight;
    296       }
    297     }
    298 
    299     qsort(weighted_deltas, kMaxKeypoints * 2, sizeof(WeightedDelta),
    300           WeightedDeltaCompare);
    301 
    302     median_delta = GetMedian(kMaxKeypoints * 2, weighted_deltas, total_weight);
    303   }
    304 
    305   return median_delta;
    306 }
    307 
    308 }  // namespace tf_tracking
    309