Home | History | Annotate | Download | only in object_tracking
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // NOTE: no native object detectors are currently provided or used by the code
     17 // in this directory. This class remains mainly for historical reasons.
     18 // Detection in the TF demo is done through TensorFlowMultiBoxDetector.java.
     19 
     20 // Defines the ObjectDetector class that is the main interface for detecting
     21 // ObjectModelBases in frames.
     22 
     23 #ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_
     24 #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_
     25 
     26 #include <float.h>
     27 #include <map>
     28 #include <memory>
     29 #include <sstream>
     30 #include <string>
     31 #include <vector>
     32 
     33 #include "tensorflow/examples/android/jni/object_tracking/geom.h"
     34 #include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
     35 #include "tensorflow/examples/android/jni/object_tracking/image.h"
     36 #include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
     37 #ifdef __RENDER_OPENGL__
     38 #include "tensorflow/examples/android/jni/object_tracking/sprite.h"
     39 #endif
     40 #include "tensorflow/examples/android/jni/object_tracking/utils.h"
     41 
     42 #include "tensorflow/examples/android/jni/object_tracking/config.h"
     43 #include "tensorflow/examples/android/jni/object_tracking/image_data.h"
     44 #include "tensorflow/examples/android/jni/object_tracking/object_model.h"
     45 
     46 namespace tf_tracking {
     47 
     48 // Adds BoundingSquares to a vector such that the first square added is centered
     49 // in the position given and of square_size, and the remaining squares are added
     50 // concentrentically, scaling down by scale_factor until the minimum threshold
     51 // size is passed.
     52 // Squares that do not fall completely within image_bounds will not be added.
     53 static inline void FillWithSquares(
     54     const BoundingBox& image_bounds,
     55     const BoundingBox& position,
     56     const float starting_square_size,
     57     const float smallest_square_size,
     58     const float scale_factor,
     59     std::vector<BoundingSquare>* const squares) {
     60   BoundingSquare descriptor_area =
     61       GetCenteredSquare(position, starting_square_size);
     62 
     63   SCHECK(scale_factor < 1.0f, "Scale factor too large at %.2f!", scale_factor);
     64 
     65   // Use a do/while loop to ensure that at least one descriptor is created.
     66   do {
     67     if (image_bounds.Contains(descriptor_area.ToBoundingBox())) {
     68       squares->push_back(descriptor_area);
     69     }
     70     descriptor_area.Scale(scale_factor);
     71   } while (descriptor_area.size_ >= smallest_square_size - EPSILON);
     72   LOGV("Created %zu squares starting from size %.2f to min size %.2f "
     73        "using scale factor: %.2f",
     74        squares->size(), starting_square_size, smallest_square_size,
     75        scale_factor);
     76 }
     77 
     78 
     79 // Represents a potential detection of a specific ObjectExemplar and Descriptor
     80 // at a specific position in the image.
     81 class Detection {
     82  public:
     83   explicit Detection(const ObjectModelBase* const object_model,
     84                      const MatchScore match_score,
     85                      const BoundingBox& bounding_box)
     86       : object_model_(object_model),
     87         match_score_(match_score),
     88         bounding_box_(bounding_box) {}
     89 
     90   Detection(const Detection& other)
     91       : object_model_(other.object_model_),
     92         match_score_(other.match_score_),
     93         bounding_box_(other.bounding_box_) {}
     94 
     95   virtual ~Detection() {}
     96 
     97   inline BoundingBox GetObjectBoundingBox() const {
     98     return bounding_box_;
     99   }
    100 
    101   inline MatchScore GetMatchScore() const {
    102     return match_score_;
    103   }
    104 
    105   inline const ObjectModelBase* GetObjectModel() const {
    106     return object_model_;
    107   }
    108 
    109   inline bool Intersects(const Detection& other) {
    110     // Check if any of the four axes separates us, there must be at least one.
    111     return bounding_box_.Intersects(other.bounding_box_);
    112   }
    113 
    114   struct Comp {
    115     inline bool operator()(const Detection& a, const Detection& b) const {
    116       return a.match_score_ > b.match_score_;
    117     }
    118   };
    119 
    120   // TODO(andrewharp): add accessors to update these instead.
    121   const ObjectModelBase* object_model_;
    122   MatchScore match_score_;
    123   BoundingBox bounding_box_;
    124 };
    125 
    126 inline std::ostream& operator<<(std::ostream& stream,
    127                                 const Detection& detection) {
    128   const BoundingBox actual_area = detection.GetObjectBoundingBox();
    129   stream << actual_area;
    130   return stream;
    131 }
    132 
    133 class ObjectDetectorBase {
    134  public:
    135   explicit ObjectDetectorBase(const ObjectDetectorConfig* const config)
    136       : config_(config),
    137         image_data_(NULL) {}
    138 
    139   virtual ~ObjectDetectorBase();
    140 
    141   // Sets the current image data. All calls to ObjectDetector other than
    142   // FillDescriptors use the image data last set.
    143   inline void SetImageData(const ImageData* const image_data) {
    144     image_data_ = image_data;
    145   }
    146 
    147   // Main entry point into the detection algorithm.
    148   // Scans the frame for candidates, tweaks them, and fills in the
    149   // given std::vector of Detection objects with acceptable matches.
    150   virtual void Detect(const std::vector<BoundingSquare>& positions,
    151                       std::vector<Detection>* const detections) const = 0;
    152 
    153   virtual ObjectModelBase* CreateObjectModel(const std::string& name) = 0;
    154 
    155   virtual void DeleteObjectModel(const std::string& name) = 0;
    156 
    157   virtual void GetObjectModels(
    158       std::vector<const ObjectModelBase*>* models) const = 0;
    159 
    160   // Creates a new ObjectExemplar from the given position in the context of
    161   // the last frame passed to NextFrame.
    162   // Will return null in the case that there's no room for a descriptor to be
    163   // created in the example area, or the example area is not completely
    164   // contained within the frame.
    165   virtual void UpdateModel(const Image<uint8_t>& base_image,
    166                            const IntegralImage& integral_image,
    167                            const BoundingBox& bounding_box, const bool locked,
    168                            ObjectModelBase* model) const = 0;
    169 
    170   virtual void Draw() const = 0;
    171 
    172   virtual bool AllowSpontaneousDetections() = 0;
    173 
    174  protected:
    175   const std::unique_ptr<const ObjectDetectorConfig> config_;
    176 
    177   // The latest frame data, upon which all detections will be performed.
    178   // Not owned by this object, just provided for reference by ObjectTracker
    179   // via SetImageData().
    180   const ImageData* image_data_;
    181 
    182  private:
    183   TF_DISALLOW_COPY_AND_ASSIGN(ObjectDetectorBase);
    184 };
    185 
    186 template <typename ModelType>
    187 class ObjectDetector : public ObjectDetectorBase {
    188  public:
    189   explicit ObjectDetector(const ObjectDetectorConfig* const config)
    190       : ObjectDetectorBase(config) {}
    191 
    192   virtual ~ObjectDetector() {
    193     typename std::map<std::string, ModelType*>::const_iterator it =
    194         object_models_.begin();
    195     for (; it != object_models_.end(); ++it) {
    196       ModelType* model = it->second;
    197       delete model;
    198     }
    199   }
    200 
    201   virtual void DeleteObjectModel(const std::string& name) {
    202     ModelType* model = object_models_[name];
    203     CHECK_ALWAYS(model != NULL, "Model was null!");
    204     object_models_.erase(name);
    205     SAFE_DELETE(model);
    206   }
    207 
    208   virtual void GetObjectModels(
    209       std::vector<const ObjectModelBase*>* models) const {
    210     typename std::map<std::string, ModelType*>::const_iterator it =
    211         object_models_.begin();
    212     for (; it != object_models_.end(); ++it) {
    213       models->push_back(it->second);
    214     }
    215   }
    216 
    217   virtual bool AllowSpontaneousDetections() {
    218     return false;
    219   }
    220 
    221  protected:
    222   std::map<std::string, ModelType*> object_models_;
    223 
    224  private:
    225   TF_DISALLOW_COPY_AND_ASSIGN(ObjectDetector);
    226 };
    227 
    228 }  // namespace tf_tracking
    229 
    230 #endif  // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_
    231