Home | History | Annotate | Download | only in object_tracking
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_
     17 #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_
     18 
     19 #ifdef __RENDER_OPENGL__
     20 #include "tensorflow/examples/android/jni/object_tracking/gl_utils.h"
     21 #endif
     22 #include "tensorflow/examples/android/jni/object_tracking/object_detector.h"
     23 
     24 namespace tf_tracking {
     25 
     26 // A TrackedObject is a specific instance of an ObjectModel, with a known
     27 // position in the world.
     28 // It provides the last known position and number of recent detection failures,
     29 // in addition to the more general appearance data associated with the object
     30 // class (which is in ObjectModel).
     31 // TODO(andrewharp): Make getters/setters follow styleguide.
     32 class TrackedObject {
     33  public:
     34   TrackedObject(const std::string& id, const Image<uint8_t>& image,
     35                 const BoundingBox& bounding_box, ObjectModelBase* const model);
     36 
     37   ~TrackedObject();
     38 
     39   void UpdatePosition(const BoundingBox& new_position, const int64_t timestamp,
     40                       const ImageData& image_data, const bool authoratative);
     41 
     42   // This method is called when the tracked object is detected at a
     43   // given position, and allows the associated Model to grow and/or prune
     44   // itself based on where the detection occurred.
     45   void OnDetection(ObjectModelBase* const model,
     46                    const BoundingBox& detection_position,
     47                    const MatchScore match_score, const int64_t timestamp,
     48                    const ImageData& image_data);
     49 
     50   // Called when there's no detection of the tracked object. This will cause
     51   // a tracking failure after enough consecutive failures if the area under
     52   // the current bounding box also doesn't meet a minimum correlation threshold
     53   // with the model.
     54   void OnDetectionFailure() {}
     55 
     56   inline bool IsVisible() const {
     57     return tracked_correlation_ >= kMinimumCorrelationForTracking ||
     58         num_consecutive_frames_below_threshold_ < kMaxNumDetectionFailures;
     59   }
     60 
     61   inline float GetCorrelation() {
     62     return tracked_correlation_;
     63   }
     64 
     65   inline MatchScore GetMatchScore() {
     66     return tracked_match_score_;
     67   }
     68 
     69   inline BoundingBox GetPosition() const {
     70     return last_known_position_;
     71   }
     72 
     73   inline BoundingBox GetLastDetectionPosition() const {
     74     return last_detection_position_;
     75   }
     76 
     77   inline const ObjectModelBase* GetModel() const {
     78     return object_model_;
     79   }
     80 
     81   inline const std::string& GetName() const {
     82     return id_;
     83   }
     84 
     85   inline void Draw() const {
     86 #ifdef __RENDER_OPENGL__
     87     if (tracked_correlation_ < kMinimumCorrelationForTracking) {
     88       glColor4f(MAX(0.0f, -tracked_correlation_),
     89                 MAX(0.0f, tracked_correlation_),
     90                 0.0f,
     91                 1.0f);
     92     } else {
     93       glColor4f(MAX(0.0f, -tracked_correlation_),
     94                 MAX(0.0f, tracked_correlation_),
     95                 1.0f,
     96                 1.0f);
     97     }
     98 
     99     // Render the box itself.
    100     BoundingBox temp_box(last_known_position_);
    101     DrawBox(temp_box);
    102 
    103     // Render a box inside this one (in case the actual box is hidden).
    104     const float kBufferSize = 1.0f;
    105     temp_box.left_ -= kBufferSize;
    106     temp_box.top_ -= kBufferSize;
    107     temp_box.right_ += kBufferSize;
    108     temp_box.bottom_ += kBufferSize;
    109     DrawBox(temp_box);
    110 
    111     // Render one outside as well.
    112     temp_box.left_ -= -2.0f * kBufferSize;
    113     temp_box.top_ -= -2.0f * kBufferSize;
    114     temp_box.right_ += -2.0f * kBufferSize;
    115     temp_box.bottom_ += -2.0f * kBufferSize;
    116     DrawBox(temp_box);
    117 #endif
    118   }
    119 
    120   // Get current object's num_consecutive_frames_below_threshold_.
    121   inline int64_t GetNumConsecutiveFramesBelowThreshold() {
    122     return num_consecutive_frames_below_threshold_;
    123   }
    124 
    125   // Reset num_consecutive_frames_below_threshold_ to 0.
    126   inline void resetNumConsecutiveFramesBelowThreshold() {
    127     num_consecutive_frames_below_threshold_ = 0;
    128   }
    129 
    130   inline float GetAllowableDistanceSquared() const {
    131     return allowable_detection_distance_;
    132   }
    133 
    134  private:
    135   // The unique id used throughout the system to identify this
    136   // tracked object.
    137   const std::string id_;
    138 
    139   // The last known position of the object.
    140   BoundingBox last_known_position_;
    141 
    142   // The last known position of the object.
    143   BoundingBox last_detection_position_;
    144 
    145   // When the position was last computed.
    146   int64_t position_last_computed_time_;
    147 
    148   // The object model this tracked object is representative of.
    149   ObjectModelBase* object_model_;
    150 
    151   Image<float> last_detection_thumbnail_;
    152 
    153   Image<float> last_frame_thumbnail_;
    154 
    155   // The correlation of the object model with the preview frame at its last
    156   // tracked position.
    157   float tracked_correlation_;
    158 
    159   MatchScore tracked_match_score_;
    160 
    161   // The number of consecutive frames that the tracked position for this object
    162   // has been under the correlation threshold.
    163   int num_consecutive_frames_below_threshold_;
    164 
    165   float allowable_detection_distance_;
    166 
    167   friend std::ostream& operator<<(std::ostream& stream,
    168                                   const TrackedObject& tracked_object);
    169 
    170   TF_DISALLOW_COPY_AND_ASSIGN(TrackedObject);
    171 };
    172 
    173 inline std::ostream& operator<<(std::ostream& stream,
    174                                 const TrackedObject& tracked_object) {
    175   stream << tracked_object.id_
    176       << " " << tracked_object.last_known_position_
    177       << " " << tracked_object.position_last_computed_time_
    178       << " " << tracked_object.num_consecutive_frames_below_threshold_
    179       << " " << tracked_object.object_model_
    180       << " " << tracked_object.tracked_correlation_;
    181   return stream;
    182 }
    183 
    184 }  // namespace tf_tracking
    185 
    186 #endif  // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_
    187