Home | History | Annotate | Download | only in object_tracking
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <android/log.h>
     17 #include <jni.h>
     18 #include <stdint.h>
     19 #include <stdlib.h>
     20 #include <string.h>
     21 #include <cstdint>
     22 
     23 #include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
     24 #include "tensorflow/examples/android/jni/object_tracking/image.h"
     25 #include "tensorflow/examples/android/jni/object_tracking/jni_utils.h"
     26 #include "tensorflow/examples/android/jni/object_tracking/time_log.h"
     27 
     28 #include "tensorflow/examples/android/jni/object_tracking/config.h"
     29 #include "tensorflow/examples/android/jni/object_tracking/object_tracker.h"
     30 
     31 namespace tf_tracking {
     32 
// Expands a short method name into the full JNI symbol name expected for a
// native method of the Java class org.tensorflow.demo.tracking.ObjectTracker.
// The resulting name must match the Java package/class exactly or the JVM will
// fail to resolve the native method at runtime.
#define OBJECT_TRACKER_METHOD(METHOD_NAME) \
  Java_org_tensorflow_demo_tracking_ObjectTracker_##METHOD_NAME  // NOLINT

// Accessor for the Java-side long field "nativeObjectTracker", which stores the
// address of the native ObjectTracker created in initNative (and cleared in
// releaseMemoryNative).
JniLongField object_tracker_field("nativeObjectTracker");
     37 
     38 ObjectTracker* get_object_tracker(JNIEnv* env, jobject thiz) {
     39   ObjectTracker* const object_tracker =
     40       reinterpret_cast<ObjectTracker*>(object_tracker_field.get(env, thiz));
     41   CHECK_ALWAYS(object_tracker != NULL, "null object tracker!");
     42   return object_tracker;
     43 }
     44 
     45 void set_object_tracker(JNIEnv* env, jobject thiz,
     46                         const ObjectTracker* object_tracker) {
     47   object_tracker_field.set(env, thiz,
     48                            reinterpret_cast<intptr_t>(object_tracker));
     49 }
     50 
// Prototypes for every JNI entry point defined later in this file. Declaring
// them inside extern "C" gives the definitions below C language linkage, i.e.
// unmangled symbol names, which is what the JVM's native-method lookup expects.
// These signatures must stay in sync with the definitions.
#ifdef __cplusplus
extern "C" {
#endif
JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(initNative)(JNIEnv* env, jobject thiz,
                                               jint width, jint height,
                                               jboolean always_track);

JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(releaseMemoryNative)(JNIEnv* env,
                                                        jobject thiz);

JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(registerNewObjectWithAppearanceNative)(
    JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
    jfloat x2, jfloat y2, jbyteArray frame_data);

JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(setPreviousPositionNative)(
    JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
    jfloat x2, jfloat y2, jlong timestamp);

JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(setCurrentPositionNative)(
    JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
    jfloat x2, jfloat y2);

JNIEXPORT
jboolean JNICALL OBJECT_TRACKER_METHOD(haveObject)(JNIEnv* env, jobject thiz,
                                                   jstring object_id);

JNIEXPORT
jboolean JNICALL OBJECT_TRACKER_METHOD(isObjectVisible)(JNIEnv* env,
                                                        jobject thiz,
                                                        jstring object_id);

JNIEXPORT
jstring JNICALL OBJECT_TRACKER_METHOD(getModelIdNative)(JNIEnv* env,
                                                        jobject thiz,
                                                        jstring object_id);

JNIEXPORT
jfloat JNICALL OBJECT_TRACKER_METHOD(getCurrentCorrelation)(JNIEnv* env,
                                                            jobject thiz,
                                                            jstring object_id);

JNIEXPORT
jfloat JNICALL OBJECT_TRACKER_METHOD(getMatchScore)(JNIEnv* env, jobject thiz,
                                                    jstring object_id);

JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(getTrackedPositionNative)(
    JNIEnv* env, jobject thiz, jstring object_id, jfloatArray rect_array);

JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(nextFrameNative)(JNIEnv* env, jobject thiz,
                                                    jbyteArray y_data,
                                                    jbyteArray uv_data,
                                                    jlong timestamp,
                                                    jfloatArray vg_matrix_2x3);

JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(forgetNative)(JNIEnv* env, jobject thiz,
                                                 jstring object_id);

JNIEXPORT
jbyteArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsPacked)(
    JNIEnv* env, jobject thiz, jfloat scale_factor);

JNIEXPORT
jfloatArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsNative)(
    JNIEnv* env, jobject thiz, jboolean only_found_);

JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(getCurrentPositionNative)(
    JNIEnv* env, jobject thiz, jlong timestamp, jfloat position_x1,
    jfloat position_y1, jfloat position_x2, jfloat position_y2,
    jfloatArray delta);

JNIEXPORT
void JNICALL OBJECT_TRACKER_METHOD(drawNative)(JNIEnv* env, jobject obj,
                                               jint view_width,
                                               jint view_height,
                                               jfloatArray delta);

JNIEXPORT void JNICALL OBJECT_TRACKER_METHOD(downsampleImageNative)(
    JNIEnv* env, jobject thiz, jint width, jint height, jint row_stride,
    jbyteArray input, jint factor, jbyteArray output);

#ifdef __cplusplus
}
#endif
    143 
    144 JNIEXPORT
    145 void JNICALL OBJECT_TRACKER_METHOD(initNative)(JNIEnv* env, jobject thiz,
    146                                                jint width, jint height,
    147                                                jboolean always_track) {
    148   LOGI("Initializing object tracker. %dx%d @%p", width, height, thiz);
    149   const Size image_size(width, height);
    150   TrackerConfig* const tracker_config = new TrackerConfig(image_size);
    151   tracker_config->always_track = always_track;
    152 
    153   // XXX detector
    154   ObjectTracker* const tracker = new ObjectTracker(tracker_config, NULL);
    155   set_object_tracker(env, thiz, tracker);
    156   LOGI("Initialized!");
    157 
    158   CHECK_ALWAYS(get_object_tracker(env, thiz) == tracker,
    159                "Failure to set hand tracker!");
    160 }
    161 
    162 JNIEXPORT
    163 void JNICALL OBJECT_TRACKER_METHOD(releaseMemoryNative)(JNIEnv* env,
    164                                                         jobject thiz) {
    165   delete get_object_tracker(env, thiz);
    166   set_object_tracker(env, thiz, NULL);
    167 }
    168 
    169 JNIEXPORT
    170 void JNICALL OBJECT_TRACKER_METHOD(registerNewObjectWithAppearanceNative)(
    171     JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
    172     jfloat x2, jfloat y2, jbyteArray frame_data) {
    173   const char* const id_str = env->GetStringUTFChars(object_id, 0);
    174 
    175   LOGI("Registering the position of %s at %.2f,%.2f,%.2f,%.2f", id_str, x1, y1,
    176        x2, y2);
    177 
    178   jboolean iCopied = JNI_FALSE;
    179 
    180   // Copy image into currFrame.
    181   jbyte* pixels = env->GetByteArrayElements(frame_data, &iCopied);
    182 
    183   BoundingBox bounding_box(x1, y1, x2, y2);
    184   get_object_tracker(env, thiz)->RegisterNewObjectWithAppearance(
    185       id_str, reinterpret_cast<const uint8_t*>(pixels), bounding_box);
    186 
    187   env->ReleaseByteArrayElements(frame_data, pixels, JNI_ABORT);
    188 
    189   env->ReleaseStringUTFChars(object_id, id_str);
    190 }
    191 
    192 JNIEXPORT
    193 void JNICALL OBJECT_TRACKER_METHOD(setPreviousPositionNative)(
    194     JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
    195     jfloat x2, jfloat y2, jlong timestamp) {
    196   const char* const id_str = env->GetStringUTFChars(object_id, 0);
    197 
    198   LOGI(
    199       "Registering the position of %s at %.2f,%.2f,%.2f,%.2f"
    200       " at time %lld",
    201       id_str, x1, y1, x2, y2, static_cast<int64_t>(timestamp));
    202 
    203   get_object_tracker(env, thiz)->SetPreviousPositionOfObject(
    204       id_str, BoundingBox(x1, y1, x2, y2), timestamp);
    205 
    206   env->ReleaseStringUTFChars(object_id, id_str);
    207 }
    208 
    209 JNIEXPORT
    210 void JNICALL OBJECT_TRACKER_METHOD(setCurrentPositionNative)(
    211     JNIEnv* env, jobject thiz, jstring object_id, jfloat x1, jfloat y1,
    212     jfloat x2, jfloat y2) {
    213   const char* const id_str = env->GetStringUTFChars(object_id, 0);
    214 
    215   LOGI("Registering the position of %s at %.2f,%.2f,%.2f,%.2f", id_str, x1, y1,
    216        x2, y2);
    217 
    218   get_object_tracker(env, thiz)->SetCurrentPositionOfObject(
    219       id_str, BoundingBox(x1, y1, x2, y2));
    220 
    221   env->ReleaseStringUTFChars(object_id, id_str);
    222 }
    223 
    224 JNIEXPORT
    225 jboolean JNICALL OBJECT_TRACKER_METHOD(haveObject)(JNIEnv* env, jobject thiz,
    226                                                    jstring object_id) {
    227   const char* const id_str = env->GetStringUTFChars(object_id, 0);
    228 
    229   const bool haveObject = get_object_tracker(env, thiz)->HaveObject(id_str);
    230   env->ReleaseStringUTFChars(object_id, id_str);
    231   return haveObject;
    232 }
    233 
    234 JNIEXPORT
    235 jboolean JNICALL OBJECT_TRACKER_METHOD(isObjectVisible)(JNIEnv* env,
    236                                                         jobject thiz,
    237                                                         jstring object_id) {
    238   const char* const id_str = env->GetStringUTFChars(object_id, 0);
    239 
    240   const bool visible = get_object_tracker(env, thiz)->IsObjectVisible(id_str);
    241   env->ReleaseStringUTFChars(object_id, id_str);
    242   return visible;
    243 }
    244 
    245 JNIEXPORT
    246 jstring JNICALL OBJECT_TRACKER_METHOD(getModelIdNative)(JNIEnv* env,
    247                                                         jobject thiz,
    248                                                         jstring object_id) {
    249   const char* const id_str = env->GetStringUTFChars(object_id, 0);
    250   const TrackedObject* const object =
    251       get_object_tracker(env, thiz)->GetObject(id_str);
    252   env->ReleaseStringUTFChars(object_id, id_str);
    253   jstring model_name = env->NewStringUTF(object->GetModel()->GetName().c_str());
    254   return model_name;
    255 }
    256 
    257 JNIEXPORT
    258 jfloat JNICALL OBJECT_TRACKER_METHOD(getCurrentCorrelation)(JNIEnv* env,
    259                                                             jobject thiz,
    260                                                             jstring object_id) {
    261   const char* const id_str = env->GetStringUTFChars(object_id, 0);
    262 
    263   const float correlation =
    264       get_object_tracker(env, thiz)->GetObject(id_str)->GetCorrelation();
    265   env->ReleaseStringUTFChars(object_id, id_str);
    266   return correlation;
    267 }
    268 
    269 JNIEXPORT
    270 jfloat JNICALL OBJECT_TRACKER_METHOD(getMatchScore)(JNIEnv* env, jobject thiz,
    271                                                     jstring object_id) {
    272   const char* const id_str = env->GetStringUTFChars(object_id, 0);
    273 
    274   const float match_score =
    275       get_object_tracker(env, thiz)->GetObject(id_str)->GetMatchScore().value;
    276   env->ReleaseStringUTFChars(object_id, id_str);
    277   return match_score;
    278 }
    279 
    280 JNIEXPORT
    281 void JNICALL OBJECT_TRACKER_METHOD(getTrackedPositionNative)(
    282     JNIEnv* env, jobject thiz, jstring object_id, jfloatArray rect_array) {
    283   jboolean iCopied = JNI_FALSE;
    284   const char* const id_str = env->GetStringUTFChars(object_id, 0);
    285 
    286   const BoundingBox bounding_box =
    287       get_object_tracker(env, thiz)->GetObject(id_str)->GetPosition();
    288   env->ReleaseStringUTFChars(object_id, id_str);
    289 
    290   jfloat* rect = env->GetFloatArrayElements(rect_array, &iCopied);
    291   bounding_box.CopyToArray(reinterpret_cast<float*>(rect));
    292   env->ReleaseFloatArrayElements(rect_array, rect, 0);
    293 }
    294 
    295 JNIEXPORT
    296 void JNICALL OBJECT_TRACKER_METHOD(nextFrameNative)(JNIEnv* env, jobject thiz,
    297                                                     jbyteArray y_data,
    298                                                     jbyteArray uv_data,
    299                                                     jlong timestamp,
    300                                                     jfloatArray vg_matrix_2x3) {
    301   TimeLog("Starting object tracker");
    302 
    303   jboolean iCopied = JNI_FALSE;
    304 
    305   float vision_gyro_matrix_array[6];
    306   jfloat* jmat = NULL;
    307 
    308   if (vg_matrix_2x3 != NULL) {
    309     // Copy the alignment matrix into a float array.
    310     jmat = env->GetFloatArrayElements(vg_matrix_2x3, &iCopied);
    311     for (int i = 0; i < 6; ++i) {
    312       vision_gyro_matrix_array[i] = static_cast<float>(jmat[i]);
    313     }
    314   }
    315   // Copy image into currFrame.
    316   jbyte* pixels = env->GetByteArrayElements(y_data, &iCopied);
    317   jbyte* uv_pixels =
    318       uv_data != NULL ? env->GetByteArrayElements(uv_data, &iCopied) : NULL;
    319 
    320   TimeLog("Got elements");
    321 
    322   // Add the frame to the object tracker object.
    323   get_object_tracker(env, thiz)->NextFrame(
    324       reinterpret_cast<uint8_t*>(pixels), reinterpret_cast<uint8_t*>(uv_pixels),
    325       timestamp, vg_matrix_2x3 != NULL ? vision_gyro_matrix_array : NULL);
    326 
    327   env->ReleaseByteArrayElements(y_data, pixels, JNI_ABORT);
    328 
    329   if (uv_data != NULL) {
    330     env->ReleaseByteArrayElements(uv_data, uv_pixels, JNI_ABORT);
    331   }
    332 
    333   if (vg_matrix_2x3 != NULL) {
    334     env->ReleaseFloatArrayElements(vg_matrix_2x3, jmat, JNI_ABORT);
    335   }
    336 
    337   TimeLog("Released elements");
    338 
    339   PrintTimeLog();
    340   ResetTimeLog();
    341 }
    342 
    343 JNIEXPORT
    344 void JNICALL OBJECT_TRACKER_METHOD(forgetNative)(JNIEnv* env, jobject thiz,
    345                                                  jstring object_id) {
    346   const char* const id_str = env->GetStringUTFChars(object_id, 0);
    347 
    348   get_object_tracker(env, thiz)->ForgetTarget(id_str);
    349 
    350   env->ReleaseStringUTFChars(object_id, id_str);
    351 }
    352 
    353 JNIEXPORT
    354 jfloatArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsNative)(
    355     JNIEnv* env, jobject thiz, jboolean only_found) {
    356   jfloat keypoint_arr[kMaxKeypoints * kKeypointStep];
    357 
    358   const int number_of_keypoints =
    359       get_object_tracker(env, thiz)->GetKeypoints(only_found, keypoint_arr);
    360 
    361   // Create and return the array that will be passed back to Java.
    362   jfloatArray keypoints =
    363       env->NewFloatArray(number_of_keypoints * kKeypointStep);
    364   if (keypoints == NULL) {
    365     LOGE("null array!");
    366     return NULL;
    367   }
    368   env->SetFloatArrayRegion(keypoints, 0, number_of_keypoints * kKeypointStep,
    369                            keypoint_arr);
    370 
    371   return keypoints;
    372 }
    373 
    374 JNIEXPORT
    375 jbyteArray JNICALL OBJECT_TRACKER_METHOD(getKeypointsPacked)(
    376     JNIEnv* env, jobject thiz, jfloat scale_factor) {
    377   // 2 bytes to a uint16_t and two pairs of xy coordinates per keypoint.
    378   const int bytes_per_keypoint = sizeof(uint16_t) * 2 * 2;
    379   jbyte keypoint_arr[kMaxKeypoints * bytes_per_keypoint];
    380 
    381   const int number_of_keypoints =
    382       get_object_tracker(env, thiz)->GetKeypointsPacked(
    383           reinterpret_cast<uint16_t*>(keypoint_arr), scale_factor);
    384 
    385   // Create and return the array that will be passed back to Java.
    386   jbyteArray keypoints =
    387       env->NewByteArray(number_of_keypoints * bytes_per_keypoint);
    388 
    389   if (keypoints == NULL) {
    390     LOGE("null array!");
    391     return NULL;
    392   }
    393 
    394   env->SetByteArrayRegion(
    395       keypoints, 0, number_of_keypoints * bytes_per_keypoint, keypoint_arr);
    396 
    397   return keypoints;
    398 }
    399 
    400 JNIEXPORT
    401 void JNICALL OBJECT_TRACKER_METHOD(getCurrentPositionNative)(
    402     JNIEnv* env, jobject thiz, jlong timestamp, jfloat position_x1,
    403     jfloat position_y1, jfloat position_x2, jfloat position_y2,
    404     jfloatArray delta) {
    405   jfloat point_arr[4];
    406 
    407   const BoundingBox new_position = get_object_tracker(env, thiz)->TrackBox(
    408       BoundingBox(position_x1, position_y1, position_x2, position_y2),
    409       timestamp);
    410 
    411   new_position.CopyToArray(point_arr);
    412   env->SetFloatArrayRegion(delta, 0, 4, point_arr);
    413 }
    414 
    415 JNIEXPORT
    416 void JNICALL OBJECT_TRACKER_METHOD(drawNative)(
    417     JNIEnv* env, jobject thiz, jint view_width, jint view_height,
    418     jfloatArray frame_to_canvas_arr) {
    419   ObjectTracker* object_tracker = get_object_tracker(env, thiz);
    420   if (object_tracker != NULL) {
    421     jfloat* frame_to_canvas =
    422         env->GetFloatArrayElements(frame_to_canvas_arr, NULL);
    423 
    424     object_tracker->Draw(view_width, view_height, frame_to_canvas);
    425     env->ReleaseFloatArrayElements(frame_to_canvas_arr, frame_to_canvas,
    426                                    JNI_ABORT);
    427   }
    428 }
    429 
    430 JNIEXPORT void JNICALL OBJECT_TRACKER_METHOD(downsampleImageNative)(
    431     JNIEnv* env, jobject thiz, jint width, jint height, jint row_stride,
    432     jbyteArray input, jint factor, jbyteArray output) {
    433   if (input == NULL || output == NULL) {
    434     LOGW("Received null arrays, hopefully this is a test!");
    435     return;
    436   }
    437 
    438   jbyte* const input_array = env->GetByteArrayElements(input, 0);
    439   jbyte* const output_array = env->GetByteArrayElements(output, 0);
    440 
    441   {
    442     tf_tracking::Image<uint8_t> full_image(
    443         width, height, reinterpret_cast<uint8_t*>(input_array), false);
    444 
    445     const int new_width = (width + factor - 1) / factor;
    446     const int new_height = (height + factor - 1) / factor;
    447 
    448     tf_tracking::Image<uint8_t> downsampled_image(
    449         new_width, new_height, reinterpret_cast<uint8_t*>(output_array), false);
    450 
    451     downsampled_image.DownsampleAveraged(
    452         reinterpret_cast<uint8_t*>(input_array), row_stride, factor);
    453   }
    454 
    455   env->ReleaseByteArrayElements(input, input_array, JNI_ABORT);
    456   env->ReleaseByteArrayElements(output, output_array, 0);
    457 }
    458 
    459 }  // namespace tf_tracking
    460