Home | History | Annotate | Download | only in object_tracking
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
     17 #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
     18 
     19 #include <stdint.h>
     20 #include <memory>
     21 
     22 #include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
     23 #include "tensorflow/examples/android/jni/object_tracking/image.h"
     24 #include "tensorflow/examples/android/jni/object_tracking/image_utils.h"
     25 #include "tensorflow/examples/android/jni/object_tracking/integral_image.h"
     26 #include "tensorflow/examples/android/jni/object_tracking/time_log.h"
     27 #include "tensorflow/examples/android/jni/object_tracking/utils.h"
     28 
     29 #include "tensorflow/examples/android/jni/object_tracking/config.h"
     30 
     31 namespace tf_tracking {
     32 
     33 // Class that encapsulates all bulky processed data for a frame.
     34 class ImageData {
     35  public:
     36   explicit ImageData(const int width, const int height)
     37       : uv_frame_width_(width << 1),
     38         uv_frame_height_(height << 1),
     39         timestamp_(0),
     40         image_(width, height) {
     41     InitPyramid(width, height);
     42     ResetComputationCache();
     43   }
     44 
     45  private:
     46   void ResetComputationCache() {
     47     uv_data_computed_ = false;
     48     integral_image_computed_ = false;
     49     for (int i = 0; i < kNumPyramidLevels; ++i) {
     50       spatial_x_computed_[i] = false;
     51       spatial_y_computed_[i] = false;
     52       pyramid_sqrt2_computed_[i * 2] = false;
     53       pyramid_sqrt2_computed_[i * 2 + 1] = false;
     54     }
     55   }
     56 
     57   void InitPyramid(const int width, const int height) {
     58     int level_width = width;
     59     int level_height = height;
     60 
     61     for (int i = 0; i < kNumPyramidLevels; ++i) {
     62       pyramid_sqrt2_[i * 2] = NULL;
     63       pyramid_sqrt2_[i * 2 + 1] = NULL;
     64       spatial_x_[i] = NULL;
     65       spatial_y_[i] = NULL;
     66 
     67       level_width /= 2;
     68       level_height /= 2;
     69     }
     70 
     71     // Alias the first pyramid level to image_.
     72     pyramid_sqrt2_[0] = &image_;
     73   }
     74 
     75  public:
     76   ~ImageData() {
     77     // The first pyramid level is actually an alias to image_,
     78     // so make sure it doesn't get deleted here.
     79     pyramid_sqrt2_[0] = NULL;
     80 
     81     for (int i = 0; i < kNumPyramidLevels; ++i) {
     82       SAFE_DELETE(pyramid_sqrt2_[i * 2]);
     83       SAFE_DELETE(pyramid_sqrt2_[i * 2 + 1]);
     84       SAFE_DELETE(spatial_x_[i]);
     85       SAFE_DELETE(spatial_y_[i]);
     86     }
     87   }
     88 
     89   void SetData(const uint8_t* const new_frame, const int stride,
     90                const int64_t timestamp, const int downsample_factor) {
     91     SetData(new_frame, NULL, stride, timestamp, downsample_factor);
     92   }
     93 
     94   void SetData(const uint8_t* const new_frame, const uint8_t* const uv_frame,
     95                const int stride, const int64_t timestamp,
     96                const int downsample_factor) {
     97     ResetComputationCache();
     98 
     99     timestamp_ = timestamp;
    100 
    101     TimeLog("SetData!");
    102 
    103     pyramid_sqrt2_[0]->FromArray(new_frame, stride, downsample_factor);
    104     pyramid_sqrt2_computed_[0] = true;
    105     TimeLog("Downsampled image");
    106 
    107     if (uv_frame != NULL) {
    108       if (u_data_.get() == NULL) {
    109         u_data_.reset(new Image<uint8_t>(uv_frame_width_, uv_frame_height_));
    110         v_data_.reset(new Image<uint8_t>(uv_frame_width_, uv_frame_height_));
    111       }
    112 
    113       GetUV(uv_frame, u_data_.get(), v_data_.get());
    114       uv_data_computed_ = true;
    115       TimeLog("Copied UV data");
    116     } else {
    117       LOGV("No uv data!");
    118     }
    119 
    120 #ifdef LOG_TIME
    121     // If profiling is enabled, precompute here to make it easier to distinguish
    122     // total costs.
    123     Precompute();
    124 #endif
    125   }
    126 
    127   inline const uint64_t GetTimestamp() const { return timestamp_; }
    128 
    129   inline const Image<uint8_t>* GetImage() const {
    130     SCHECK(pyramid_sqrt2_computed_[0], "image not set!");
    131     return pyramid_sqrt2_[0];
    132   }
    133 
    134   const Image<uint8_t>* GetPyramidSqrt2Level(const int level) const {
    135     if (!pyramid_sqrt2_computed_[level]) {
    136       SCHECK(level != 0, "Level equals 0!");
    137       if (level == 1) {
    138         const Image<uint8_t>& upper_level = *GetPyramidSqrt2Level(0);
    139         if (pyramid_sqrt2_[level] == NULL) {
    140           const int new_width =
    141               (static_cast<int>(upper_level.GetWidth() / sqrtf(2)) + 1) / 2 * 2;
    142           const int new_height =
    143               (static_cast<int>(upper_level.GetHeight() / sqrtf(2)) + 1) / 2 *
    144               2;
    145 
    146           pyramid_sqrt2_[level] = new Image<uint8_t>(new_width, new_height);
    147         }
    148         pyramid_sqrt2_[level]->DownsampleInterpolateLinear(upper_level);
    149       } else {
    150         const Image<uint8_t>& upper_level = *GetPyramidSqrt2Level(level - 2);
    151         if (pyramid_sqrt2_[level] == NULL) {
    152           pyramid_sqrt2_[level] = new Image<uint8_t>(
    153               upper_level.GetWidth() / 2, upper_level.GetHeight() / 2);
    154         }
    155         pyramid_sqrt2_[level]->DownsampleAveraged(
    156             upper_level.data(), upper_level.stride(), 2);
    157       }
    158       pyramid_sqrt2_computed_[level] = true;
    159     }
    160     return pyramid_sqrt2_[level];
    161   }
    162 
    163   inline const Image<int32_t>* GetSpatialX(const int level) const {
    164     if (!spatial_x_computed_[level]) {
    165       const Image<uint8_t>& src = *GetPyramidSqrt2Level(level * 2);
    166       if (spatial_x_[level] == NULL) {
    167         spatial_x_[level] = new Image<int32_t>(src.GetWidth(), src.GetHeight());
    168       }
    169       spatial_x_[level]->DerivativeX(src);
    170       spatial_x_computed_[level] = true;
    171     }
    172     return spatial_x_[level];
    173   }
    174 
    175   inline const Image<int32_t>* GetSpatialY(const int level) const {
    176     if (!spatial_y_computed_[level]) {
    177       const Image<uint8_t>& src = *GetPyramidSqrt2Level(level * 2);
    178       if (spatial_y_[level] == NULL) {
    179         spatial_y_[level] = new Image<int32_t>(src.GetWidth(), src.GetHeight());
    180       }
    181       spatial_y_[level]->DerivativeY(src);
    182       spatial_y_computed_[level] = true;
    183     }
    184     return spatial_y_[level];
    185   }
    186 
    187   // The integral image is currently only used for object detection, so lazily
    188   // initialize it on request.
    189   inline const IntegralImage* GetIntegralImage() const {
    190     if (integral_image_.get() == NULL) {
    191       integral_image_.reset(new IntegralImage(image_));
    192     } else if (!integral_image_computed_) {
    193       integral_image_->Recompute(image_);
    194     }
    195     integral_image_computed_ = true;
    196     return integral_image_.get();
    197   }
    198 
    199   inline const Image<uint8_t>* GetU() const {
    200     SCHECK(uv_data_computed_, "UV data not provided!");
    201     return u_data_.get();
    202   }
    203 
    204   inline const Image<uint8_t>* GetV() const {
    205     SCHECK(uv_data_computed_, "UV data not provided!");
    206     return v_data_.get();
    207   }
    208 
    209  private:
    210   void Precompute() {
    211     // Create the smoothed pyramids.
    212     for (int i = 0; i < kNumPyramidLevels * 2; i += 2) {
    213       (void) GetPyramidSqrt2Level(i);
    214     }
    215     TimeLog("Created smoothed pyramids");
    216 
    217     // Create the smoothed pyramids.
    218     for (int i = 1; i < kNumPyramidLevels * 2; i += 2) {
    219       (void) GetPyramidSqrt2Level(i);
    220     }
    221     TimeLog("Created smoothed sqrt pyramids");
    222 
    223     // Create the spatial derivatives for frame 1.
    224     for (int i = 0; i < kNumPyramidLevels; ++i) {
    225       (void) GetSpatialX(i);
    226       (void) GetSpatialY(i);
    227     }
    228     TimeLog("Created spatial derivatives");
    229 
    230     (void) GetIntegralImage();
    231     TimeLog("Got integral image!");
    232   }
    233 
    234   const int uv_frame_width_;
    235   const int uv_frame_height_;
    236 
    237   int64_t timestamp_;
    238 
    239   Image<uint8_t> image_;
    240 
    241   bool uv_data_computed_;
    242   std::unique_ptr<Image<uint8_t> > u_data_;
    243   std::unique_ptr<Image<uint8_t> > v_data_;
    244 
    245   mutable bool spatial_x_computed_[kNumPyramidLevels];
    246   mutable Image<int32_t>* spatial_x_[kNumPyramidLevels];
    247 
    248   mutable bool spatial_y_computed_[kNumPyramidLevels];
    249   mutable Image<int32_t>* spatial_y_[kNumPyramidLevels];
    250 
    251   // Mutable so the lazy initialization can work when this class is const.
    252   // Whether or not the integral image has been computed for the current image.
    253   mutable bool integral_image_computed_;
    254   mutable std::unique_ptr<IntegralImage> integral_image_;
    255 
    256   mutable bool pyramid_sqrt2_computed_[kNumPyramidLevels * 2];
    257   mutable Image<uint8_t>* pyramid_sqrt2_[kNumPyramidLevels * 2];
    258 
    259   TF_DISALLOW_COPY_AND_ASSIGN(ImageData);
    260 };
    261 
    262 }  // namespace tf_tracking
    263 
    264 #endif  // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
    265