/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// NOTE: no native object detectors are currently provided or used by the code
// in this directory. This class remains mainly for historical reasons.
// Detection in the TF demo is done through TensorFlowMultiBoxDetector.java.

// Defines the ObjectDetector class that is the main interface for detecting
// ObjectModelBases in frames.
22 23 #ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_ 24 #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_ 25 26 #include <float.h> 27 #include <map> 28 #include <memory> 29 #include <sstream> 30 #include <string> 31 #include <vector> 32 33 #include "tensorflow/examples/android/jni/object_tracking/geom.h" 34 #include "tensorflow/examples/android/jni/object_tracking/image-inl.h" 35 #include "tensorflow/examples/android/jni/object_tracking/image.h" 36 #include "tensorflow/examples/android/jni/object_tracking/integral_image.h" 37 #ifdef __RENDER_OPENGL__ 38 #include "tensorflow/examples/android/jni/object_tracking/sprite.h" 39 #endif 40 #include "tensorflow/examples/android/jni/object_tracking/utils.h" 41 42 #include "tensorflow/examples/android/jni/object_tracking/config.h" 43 #include "tensorflow/examples/android/jni/object_tracking/image_data.h" 44 #include "tensorflow/examples/android/jni/object_tracking/object_model.h" 45 46 namespace tf_tracking { 47 48 // Adds BoundingSquares to a vector such that the first square added is centered 49 // in the position given and of square_size, and the remaining squares are added 50 // concentrentically, scaling down by scale_factor until the minimum threshold 51 // size is passed. 52 // Squares that do not fall completely within image_bounds will not be added. 53 static inline void FillWithSquares( 54 const BoundingBox& image_bounds, 55 const BoundingBox& position, 56 const float starting_square_size, 57 const float smallest_square_size, 58 const float scale_factor, 59 std::vector<BoundingSquare>* const squares) { 60 BoundingSquare descriptor_area = 61 GetCenteredSquare(position, starting_square_size); 62 63 SCHECK(scale_factor < 1.0f, "Scale factor too large at %.2f!", scale_factor); 64 65 // Use a do/while loop to ensure that at least one descriptor is created. 
66 do { 67 if (image_bounds.Contains(descriptor_area.ToBoundingBox())) { 68 squares->push_back(descriptor_area); 69 } 70 descriptor_area.Scale(scale_factor); 71 } while (descriptor_area.size_ >= smallest_square_size - EPSILON); 72 LOGV("Created %zu squares starting from size %.2f to min size %.2f " 73 "using scale factor: %.2f", 74 squares->size(), starting_square_size, smallest_square_size, 75 scale_factor); 76 } 77 78 79 // Represents a potential detection of a specific ObjectExemplar and Descriptor 80 // at a specific position in the image. 81 class Detection { 82 public: 83 explicit Detection(const ObjectModelBase* const object_model, 84 const MatchScore match_score, 85 const BoundingBox& bounding_box) 86 : object_model_(object_model), 87 match_score_(match_score), 88 bounding_box_(bounding_box) {} 89 90 Detection(const Detection& other) 91 : object_model_(other.object_model_), 92 match_score_(other.match_score_), 93 bounding_box_(other.bounding_box_) {} 94 95 virtual ~Detection() {} 96 97 inline BoundingBox GetObjectBoundingBox() const { 98 return bounding_box_; 99 } 100 101 inline MatchScore GetMatchScore() const { 102 return match_score_; 103 } 104 105 inline const ObjectModelBase* GetObjectModel() const { 106 return object_model_; 107 } 108 109 inline bool Intersects(const Detection& other) { 110 // Check if any of the four axes separates us, there must be at least one. 111 return bounding_box_.Intersects(other.bounding_box_); 112 } 113 114 struct Comp { 115 inline bool operator()(const Detection& a, const Detection& b) const { 116 return a.match_score_ > b.match_score_; 117 } 118 }; 119 120 // TODO(andrewharp): add accessors to update these instead. 
121 const ObjectModelBase* object_model_; 122 MatchScore match_score_; 123 BoundingBox bounding_box_; 124 }; 125 126 inline std::ostream& operator<<(std::ostream& stream, 127 const Detection& detection) { 128 const BoundingBox actual_area = detection.GetObjectBoundingBox(); 129 stream << actual_area; 130 return stream; 131 } 132 133 class ObjectDetectorBase { 134 public: 135 explicit ObjectDetectorBase(const ObjectDetectorConfig* const config) 136 : config_(config), 137 image_data_(NULL) {} 138 139 virtual ~ObjectDetectorBase(); 140 141 // Sets the current image data. All calls to ObjectDetector other than 142 // FillDescriptors use the image data last set. 143 inline void SetImageData(const ImageData* const image_data) { 144 image_data_ = image_data; 145 } 146 147 // Main entry point into the detection algorithm. 148 // Scans the frame for candidates, tweaks them, and fills in the 149 // given std::vector of Detection objects with acceptable matches. 150 virtual void Detect(const std::vector<BoundingSquare>& positions, 151 std::vector<Detection>* const detections) const = 0; 152 153 virtual ObjectModelBase* CreateObjectModel(const std::string& name) = 0; 154 155 virtual void DeleteObjectModel(const std::string& name) = 0; 156 157 virtual void GetObjectModels( 158 std::vector<const ObjectModelBase*>* models) const = 0; 159 160 // Creates a new ObjectExemplar from the given position in the context of 161 // the last frame passed to NextFrame. 162 // Will return null in the case that there's no room for a descriptor to be 163 // created in the example area, or the example area is not completely 164 // contained within the frame. 
165 virtual void UpdateModel(const Image<uint8_t>& base_image, 166 const IntegralImage& integral_image, 167 const BoundingBox& bounding_box, const bool locked, 168 ObjectModelBase* model) const = 0; 169 170 virtual void Draw() const = 0; 171 172 virtual bool AllowSpontaneousDetections() = 0; 173 174 protected: 175 const std::unique_ptr<const ObjectDetectorConfig> config_; 176 177 // The latest frame data, upon which all detections will be performed. 178 // Not owned by this object, just provided for reference by ObjectTracker 179 // via SetImageData(). 180 const ImageData* image_data_; 181 182 private: 183 TF_DISALLOW_COPY_AND_ASSIGN(ObjectDetectorBase); 184 }; 185 186 template <typename ModelType> 187 class ObjectDetector : public ObjectDetectorBase { 188 public: 189 explicit ObjectDetector(const ObjectDetectorConfig* const config) 190 : ObjectDetectorBase(config) {} 191 192 virtual ~ObjectDetector() { 193 typename std::map<std::string, ModelType*>::const_iterator it = 194 object_models_.begin(); 195 for (; it != object_models_.end(); ++it) { 196 ModelType* model = it->second; 197 delete model; 198 } 199 } 200 201 virtual void DeleteObjectModel(const std::string& name) { 202 ModelType* model = object_models_[name]; 203 CHECK_ALWAYS(model != NULL, "Model was null!"); 204 object_models_.erase(name); 205 SAFE_DELETE(model); 206 } 207 208 virtual void GetObjectModels( 209 std::vector<const ObjectModelBase*>* models) const { 210 typename std::map<std::string, ModelType*>::const_iterator it = 211 object_models_.begin(); 212 for (; it != object_models_.end(); ++it) { 213 models->push_back(it->second); 214 } 215 } 216 217 virtual bool AllowSpontaneousDetections() { 218 return false; 219 } 220 221 protected: 222 std::map<std::string, ModelType*> object_models_; 223 224 private: 225 TF_DISALLOW_COPY_AND_ASSIGN(ObjectDetector); 226 }; 227 228 } // namespace tf_tracking 229 230 #endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_ 231