Home | History | Annotate | Download | only in native
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
#include <limits>
#include <memory>
#include <string>
#include <vector>
     18 
     19 #include "tensorflow/c/c_api.h"
     20 #include "tensorflow/java/src/main/native/exception_jni.h"
     21 #include "tensorflow/java/src/main/native/saved_model_bundle_jni.h"
     22 
     23 JNIEXPORT jobject JNICALL Java_org_tensorflow_SavedModelBundle_load(
     24     JNIEnv* env, jclass clazz, jstring export_dir, jobjectArray tags,
     25     jbyteArray config, jbyteArray run_options) {
     26   TF_Status* status = TF_NewStatus();
     27   jobject bundle = nullptr;
     28 
     29   // allocate parameters for TF_LoadSessionFromSavedModel
     30   TF_SessionOptions* opts = TF_NewSessionOptions();
     31   if (config != nullptr) {
     32     size_t sz = env->GetArrayLength(config);
     33     if (sz > 0) {
     34       jbyte* config_data = env->GetByteArrayElements(config, nullptr);
     35       TF_SetConfig(opts, static_cast<void*>(config_data), sz, status);
     36       env->ReleaseByteArrayElements(config, config_data, JNI_ABORT);
     37       if (!throwExceptionIfNotOK(env, status)) {
     38         TF_DeleteSessionOptions(opts);
     39         TF_DeleteStatus(status);
     40         return nullptr;
     41       }
     42     }
     43   }
     44   TF_Buffer* crun_options = nullptr;
     45   if (run_options != nullptr) {
     46     size_t sz = env->GetArrayLength(run_options);
     47     if (sz > 0) {
     48       jbyte* run_options_data = env->GetByteArrayElements(run_options, nullptr);
     49       crun_options =
     50           TF_NewBufferFromString(static_cast<void*>(run_options_data), sz);
     51       env->ReleaseByteArrayElements(run_options, run_options_data, JNI_ABORT);
     52     }
     53   }
     54   const char* cexport_dir = env->GetStringUTFChars(export_dir, nullptr);
     55   std::unique_ptr<const char* []> tags_ptrs;
     56   size_t tags_len = env->GetArrayLength(tags);
     57   tags_ptrs.reset(new const char*[tags_len]);
     58   for (size_t i = 0; i < tags_len; ++i) {
     59     jstring tag = static_cast<jstring>(env->GetObjectArrayElement(tags, i));
     60     tags_ptrs[i] = env->GetStringUTFChars(tag, nullptr);
     61     env->DeleteLocalRef(tag);
     62   }
     63 
     64   // load the session
     65   TF_Graph* graph = TF_NewGraph();
     66   TF_Buffer* metagraph_def = TF_NewBuffer();
     67   TF_Session* session = TF_LoadSessionFromSavedModel(
     68       opts, crun_options, cexport_dir, tags_ptrs.get(), tags_len, graph,
     69       metagraph_def, status);
     70 
     71   // release the parameters
     72   TF_DeleteSessionOptions(opts);
     73   if (crun_options != nullptr) {
     74     TF_DeleteBuffer(crun_options);
     75   }
     76   env->ReleaseStringUTFChars(export_dir, cexport_dir);
     77   for (size_t i = 0; i < tags_len; ++i) {
     78     jstring tag = static_cast<jstring>(env->GetObjectArrayElement(tags, i));
     79     env->ReleaseStringUTFChars(tag, tags_ptrs[i]);
     80     env->DeleteLocalRef(tag);
     81   }
     82 
     83   // handle the result
     84   if (throwExceptionIfNotOK(env, status)) {
     85     // sizeof(jsize) is less than sizeof(size_t) on some platforms.
     86     if (metagraph_def->length > std::numeric_limits<jint>::max()) {
     87       throwException(
     88           env, kIndexOutOfBoundsException,
     89           "MetaGraphDef is too large to serialize into a byte[] array");
     90     } else {
     91       static_assert(sizeof(jbyte) == 1, "unexpected size of the jbyte type");
     92       jint jmetagraph_len = static_cast<jint>(metagraph_def->length);
     93       jbyteArray jmetagraph_def = env->NewByteArray(jmetagraph_len);
     94       env->SetByteArrayRegion(jmetagraph_def, 0, jmetagraph_len,
     95                               static_cast<const jbyte*>(metagraph_def->data));
     96 
     97       jmethodID method = env->GetStaticMethodID(
     98           clazz, "fromHandle", "(JJ[B)Lorg/tensorflow/SavedModelBundle;");
     99       bundle = env->CallStaticObjectMethod(
    100           clazz, method, reinterpret_cast<jlong>(graph),
    101           reinterpret_cast<jlong>(session), jmetagraph_def);
    102       graph = nullptr;
    103       session = nullptr;
    104       env->DeleteLocalRef(jmetagraph_def);
    105     }
    106   }
    107 
    108   if (session != nullptr) {
    109     TF_CloseSession(session, status);
    110     // Result of close is ignored, delete anyway.
    111     TF_DeleteSession(session, status);
    112   }
    113   if (graph != nullptr) {
    114     TF_DeleteGraph(graph);
    115   }
    116   TF_DeleteBuffer(metagraph_def);
    117   TF_DeleteStatus(status);
    118 
    119   return bundle;
    120 }
    121