Home | History | Annotate | Download | only in native
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <string.h>
     17 #include <memory>
     18 
     19 #include "tensorflow/c/c_api.h"
     20 #include "tensorflow/java/src/main/native/utils_jni.h"
     21 #include "tensorflow/java/src/main/native/exception_jni.h"
     22 #include "tensorflow/java/src/main/native/session_jni.h"
     23 
     24 namespace {
     25 TF_Session* requireHandle(JNIEnv* env, jlong handle) {
     26   static_assert(sizeof(jlong) >= sizeof(TF_Session*),
     27                 "Cannot package C object pointers as a Java long");
     28   if (handle == 0) {
     29     throwException(env, kNullPointerException,
     30                    "close() has been called on the Session");
     31     return nullptr;
     32   }
     33   return reinterpret_cast<TF_Session*>(handle);
     34 }
     35 
     36 template <class T>
     37 void resolveHandles(JNIEnv* env, const char* type, jlongArray src_array,
     38                     T** dst, jint n) {
     39   if (env->ExceptionCheck()) return;
     40   jint len = env->GetArrayLength(src_array);
     41   if (len != n) {
     42     throwException(env, kIllegalArgumentException, "expected %d, got %d %s", n,
     43                    len, type);
     44     return;
     45   }
     46   jlong* src_start = env->GetLongArrayElements(src_array, nullptr);
     47   jlong* src = src_start;
     48   for (int i = 0; i < n; ++i, ++src, ++dst) {
     49     if (*src == 0) {
     50       throwException(env, kNullPointerException, "invalid %s (#%d of %d)", type,
     51                      i, n);
     52       break;
     53     }
     54     *dst = reinterpret_cast<T*>(*src);
     55   }
     56   env->ReleaseLongArrayElements(src_array, src_start, JNI_ABORT);
     57 }
     58 
     59 void TF_MaybeDeleteBuffer(TF_Buffer* buf) {
     60   if (buf == nullptr) return;
     61   TF_DeleteBuffer(buf);
     62 }
     63 
     64 typedef std::unique_ptr<TF_Buffer, decltype(&TF_MaybeDeleteBuffer)>
     65     unique_tf_buffer;
     66 
     67 unique_tf_buffer MakeUniqueBuffer(TF_Buffer* buf) {
     68   return unique_tf_buffer(buf, TF_MaybeDeleteBuffer);
     69 }
     70 
     71 }  // namespace
     72 
     73 JNIEXPORT jlong JNICALL Java_org_tensorflow_Session_allocate(
     74     JNIEnv* env, jclass clazz, jlong graph_handle) {
     75   return Java_org_tensorflow_Session_allocate2(env, clazz, graph_handle,
     76                                                nullptr, nullptr);
     77 }
     78 
     79 JNIEXPORT jlong JNICALL Java_org_tensorflow_Session_allocate2(
     80     JNIEnv* env, jclass clazz, jlong graph_handle, jstring target,
     81     jbyteArray config) {
     82   if (graph_handle == 0) {
     83     throwException(env, kNullPointerException, "Graph has been close()d");
     84     return 0;
     85   }
     86   TF_Graph* graph = reinterpret_cast<TF_Graph*>(graph_handle);
     87   TF_Status* status = TF_NewStatus();
     88   TF_SessionOptions* opts = TF_NewSessionOptions();
     89   jbyte* cconfig = nullptr;
     90   if (config != nullptr) {
     91     cconfig = env->GetByteArrayElements(config, nullptr);
     92     TF_SetConfig(opts, cconfig,
     93                  static_cast<size_t>(env->GetArrayLength(config)), status);
     94     if (!throwExceptionIfNotOK(env, status)) {
     95       env->ReleaseByteArrayElements(config, cconfig, JNI_ABORT);
     96       TF_DeleteSessionOptions(opts);
     97       TF_DeleteStatus(status);
     98       return 0;
     99     }
    100   }
    101   const char* ctarget = nullptr;
    102   if (target != nullptr) {
    103     ctarget = env->GetStringUTFChars(target, nullptr);
    104   }
    105   TF_Session* session = TF_NewSession(graph, opts, status);
    106   if (config != nullptr) {
    107     env->ReleaseByteArrayElements(config, cconfig, JNI_ABORT);
    108   }
    109   if (target != nullptr) {
    110     env->ReleaseStringUTFChars(target, ctarget);
    111   }
    112   TF_DeleteSessionOptions(opts);
    113   bool ok = throwExceptionIfNotOK(env, status);
    114   TF_DeleteStatus(status);
    115 
    116   return ok ? reinterpret_cast<jlong>(session) : 0;
    117 }
    118 
    119 JNIEXPORT void JNICALL Java_org_tensorflow_Session_delete(JNIEnv* env,
    120                                                           jclass clazz,
    121                                                           jlong handle) {
    122   TF_Session* session = requireHandle(env, handle);
    123   if (session == nullptr) return;
    124   TF_Status* status = TF_NewStatus();
    125   TF_CloseSession(session, status);
    126   // Result of close is ignored, delete anyway.
    127   TF_DeleteSession(session, status);
    128   throwExceptionIfNotOK(env, status);
    129   TF_DeleteStatus(status);
    130 }
    131 
    132 JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Session_run(
    133     JNIEnv* env, jclass clazz, jlong handle, jbyteArray jrun_options,
    134     jlongArray input_tensor_handles, jlongArray input_op_handles,
    135     jintArray input_op_indices, jlongArray output_op_handles,
    136     jintArray output_op_indices, jlongArray target_op_handles,
    137     jboolean want_run_metadata, jlongArray output_tensor_handles) {
    138   TF_Session* session = requireHandle(env, handle);
    139   if (session == nullptr) return nullptr;
    140 
    141   const jint ninputs = env->GetArrayLength(input_tensor_handles);
    142   const jint noutputs = env->GetArrayLength(output_tensor_handles);
    143   const jint ntargets = env->GetArrayLength(target_op_handles);
    144 
    145   std::unique_ptr<TF_Output[]> inputs(new TF_Output[ninputs]);
    146   std::unique_ptr<TF_Tensor* []> input_values(new TF_Tensor*[ninputs]);
    147   std::unique_ptr<TF_Output[]> outputs(new TF_Output[noutputs]);
    148   std::unique_ptr<TF_Tensor* []> output_values(new TF_Tensor*[noutputs]);
    149   std::unique_ptr<TF_Operation* []> targets(new TF_Operation*[ntargets]);
    150   unique_tf_buffer run_metadata(
    151       MakeUniqueBuffer(want_run_metadata ? TF_NewBuffer() : nullptr));
    152 
    153   resolveHandles(env, "input Tensors", input_tensor_handles, input_values.get(),
    154                  ninputs);
    155   resolveOutputs(env, "input", input_op_handles, input_op_indices, inputs.get(),
    156                  ninputs);
    157   resolveOutputs(env, "output", output_op_handles, output_op_indices,
    158                  outputs.get(), noutputs);
    159   resolveHandles(env, "target Operations", target_op_handles, targets.get(),
    160                  ntargets);
    161   if (env->ExceptionCheck()) return nullptr;
    162 
    163   TF_Status* status = TF_NewStatus();
    164 
    165   unique_tf_buffer run_options(MakeUniqueBuffer(nullptr));
    166   jbyte* jrun_options_data = nullptr;
    167   if (jrun_options != nullptr) {
    168     size_t sz = env->GetArrayLength(jrun_options);
    169     if (sz > 0) {
    170       jrun_options_data = env->GetByteArrayElements(jrun_options, nullptr);
    171       run_options.reset(
    172           TF_NewBufferFromString(static_cast<void*>(jrun_options_data), sz));
    173     }
    174   }
    175 
    176   TF_SessionRun(session, run_options.get(), inputs.get(), input_values.get(),
    177                 static_cast<int>(ninputs), outputs.get(), output_values.get(),
    178                 static_cast<int>(noutputs), targets.get(),
    179                 static_cast<int>(ntargets), run_metadata.get(), status);
    180 
    181   if (jrun_options_data != nullptr) {
    182     env->ReleaseByteArrayElements(jrun_options, jrun_options_data, JNI_ABORT);
    183   }
    184 
    185   if (!throwExceptionIfNotOK(env, status)) {
    186     TF_DeleteStatus(status);
    187     return nullptr;
    188   }
    189   jlong* t = env->GetLongArrayElements(output_tensor_handles, nullptr);
    190   for (int i = 0; i < noutputs; ++i) {
    191     t[i] = reinterpret_cast<jlong>(output_values[i]);
    192   }
    193   env->ReleaseLongArrayElements(output_tensor_handles, t, 0);
    194 
    195   jbyteArray ret = nullptr;
    196   if (run_metadata != nullptr) {
    197     ret = env->NewByteArray(run_metadata->length);
    198     env->SetByteArrayRegion(ret, 0, run_metadata->length,
    199                             reinterpret_cast<const jbyte*>(run_metadata->data));
    200   }
    201   TF_DeleteStatus(status);
    202   return ret;
    203 }
    204