1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <limits> 17 #include <memory> 18 19 #include "tensorflow/c/c_api.h" 20 #include "tensorflow/java/src/main/native/exception_jni.h" 21 #include "tensorflow/java/src/main/native/saved_model_bundle_jni.h" 22 23 JNIEXPORT jobject JNICALL Java_org_tensorflow_SavedModelBundle_load( 24 JNIEnv* env, jclass clazz, jstring export_dir, jobjectArray tags, 25 jbyteArray config, jbyteArray run_options) { 26 TF_Status* status = TF_NewStatus(); 27 jobject bundle = nullptr; 28 29 // allocate parameters for TF_LoadSessionFromSavedModel 30 TF_SessionOptions* opts = TF_NewSessionOptions(); 31 if (config != nullptr) { 32 size_t sz = env->GetArrayLength(config); 33 if (sz > 0) { 34 jbyte* config_data = env->GetByteArrayElements(config, nullptr); 35 TF_SetConfig(opts, static_cast<void*>(config_data), sz, status); 36 env->ReleaseByteArrayElements(config, config_data, JNI_ABORT); 37 if (!throwExceptionIfNotOK(env, status)) { 38 TF_DeleteSessionOptions(opts); 39 TF_DeleteStatus(status); 40 return nullptr; 41 } 42 } 43 } 44 TF_Buffer* crun_options = nullptr; 45 if (run_options != nullptr) { 46 size_t sz = env->GetArrayLength(run_options); 47 if (sz > 0) { 48 jbyte* run_options_data = env->GetByteArrayElements(run_options, nullptr); 49 crun_options = 50 TF_NewBufferFromString(static_cast<void*>(run_options_data), sz); 51 env->ReleaseByteArrayElements(run_options, run_options_data, JNI_ABORT); 52 } 53 } 54 const char* cexport_dir = env->GetStringUTFChars(export_dir, nullptr); 55 std::unique_ptr<const char* []> tags_ptrs; 56 size_t tags_len = env->GetArrayLength(tags); 57 tags_ptrs.reset(new const char*[tags_len]); 58 for (size_t i = 0; i < tags_len; ++i) { 59 jstring tag = static_cast<jstring>(env->GetObjectArrayElement(tags, i)); 60 tags_ptrs[i] = env->GetStringUTFChars(tag, nullptr); 61 env->DeleteLocalRef(tag); 62 } 63 64 // load the session 65 TF_Graph* graph = TF_NewGraph(); 66 TF_Buffer* metagraph_def = TF_NewBuffer(); 67 TF_Session* session = TF_LoadSessionFromSavedModel( 68 opts, crun_options, cexport_dir, tags_ptrs.get(), tags_len, graph, 69 metagraph_def, status); 70 71 // release the parameters 72 TF_DeleteSessionOptions(opts); 73 if (crun_options != nullptr) { 74 TF_DeleteBuffer(crun_options); 75 } 76 env->ReleaseStringUTFChars(export_dir, cexport_dir); 77 for (size_t i = 0; i < tags_len; ++i) { 78 jstring tag = static_cast<jstring>(env->GetObjectArrayElement(tags, i)); 79 env->ReleaseStringUTFChars(tag, tags_ptrs[i]); 80 env->DeleteLocalRef(tag); 81 } 82 83 // handle the result 84 if (throwExceptionIfNotOK(env, status)) { 85 // sizeof(jsize) is less than sizeof(size_t) on some platforms. 86 if (metagraph_def->length > std::numeric_limits<jint>::max()) { 87 throwException( 88 env, kIndexOutOfBoundsException, 89 "MetaGraphDef is too large to serialize into a byte[] array"); 90 } else { 91 static_assert(sizeof(jbyte) == 1, "unexpected size of the jbyte type"); 92 jint jmetagraph_len = static_cast<jint>(metagraph_def->length); 93 jbyteArray jmetagraph_def = env->NewByteArray(jmetagraph_len); 94 env->SetByteArrayRegion(jmetagraph_def, 0, jmetagraph_len, 95 static_cast<const jbyte*>(metagraph_def->data)); 96 97 jmethodID method = env->GetStaticMethodID( 98 clazz, "fromHandle", "(JJ[B)Lorg/tensorflow/SavedModelBundle;"); 99 bundle = env->CallStaticObjectMethod( 100 clazz, method, reinterpret_cast<jlong>(graph), 101 reinterpret_cast<jlong>(session), jmetagraph_def); 102 graph = nullptr; 103 session = nullptr; 104 env->DeleteLocalRef(jmetagraph_def); 105 } 106 } 107 108 if (session != nullptr) { 109 TF_CloseSession(session, status); 110 // Result of close is ignored, delete anyway. 111 TF_DeleteSession(session, status); 112 } 113 if (graph != nullptr) { 114 TF_DeleteGraph(graph); 115 } 116 TF_DeleteBuffer(metagraph_def); 117 TF_DeleteStatus(status); 118 119 return bundle; 120 } 121