Home | History | Annotate | Download | only in native
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/java/src/main/native/graph_jni.h"
     17 
     18 #include <limits>
     19 #include <memory>
     20 #include "tensorflow/c/c_api.h"
     21 #include "tensorflow/java/src/main/native/exception_jni.h"
     22 #include "tensorflow/java/src/main/native/utils_jni.h"
     23 
     24 namespace {
     25 template <class T>
     26 T* requireHandleImpl(JNIEnv* env, jlong handle) {
     27   static_assert(sizeof(jlong) >= sizeof(T*),
     28                 "Cannot package C object pointers as a Java long");
     29   if (handle == 0) {
     30     throwException(env, kIllegalStateException,
     31                    "close() has been called on the Graph");
     32     return nullptr;
     33   }
     34   return reinterpret_cast<T*>(handle);
     35 }
     36 
     37 TF_Graph* requireHandle(JNIEnv* env, jlong handle) {
     38   return requireHandleImpl<TF_Graph>(env, handle);
     39 }
     40 
     41 TF_Operation* requireOperationHandle(JNIEnv* env, jlong handle) {
     42   return requireHandleImpl<TF_Operation>(env, handle);
     43 }
     44 }  // namespace
     45 
     46 JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_allocate(JNIEnv*, jclass) {
     47   return reinterpret_cast<jlong>(TF_NewGraph());
     48 }
     49 
     50 JNIEXPORT void JNICALL Java_org_tensorflow_Graph_delete(JNIEnv*, jclass,
     51                                                         jlong handle) {
     52   if (handle == 0) return;
     53   TF_DeleteGraph(reinterpret_cast<TF_Graph*>(handle));
     54 }
     55 
     56 JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_operation(JNIEnv* env,
     57                                                             jclass clazz,
     58                                                             jlong handle,
     59                                                             jstring name) {
     60   TF_Graph* g = requireHandle(env, handle);
     61   if (g == nullptr) return 0;
     62   const char* cname = env->GetStringUTFChars(name, nullptr);
     63   TF_Operation* op = TF_GraphOperationByName(g, cname);
     64   env->ReleaseStringUTFChars(name, cname);
     65   return reinterpret_cast<jlong>(op);
     66 }
     67 
     68 JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_nextOperation(
     69     JNIEnv* env, jclass clazz, jlong handle, jint position) {
     70   TF_Graph* g = requireHandle(env, handle);
     71   if (g == nullptr) return nullptr;
     72 
     73   size_t pos = static_cast<size_t>(position);
     74   TF_Operation* operation = TF_GraphNextOperation(g, &pos);
     75   if (operation == nullptr) return nullptr;
     76 
     77   jlong handle_and_position[2];
     78   handle_and_position[0] = reinterpret_cast<jlong>(operation);
     79   handle_and_position[1] = static_cast<jlong>(pos);
     80 
     81   jlongArray rhett = env->NewLongArray(2);
     82   env->SetLongArrayRegion(rhett, 0, 2, handle_and_position);
     83   return rhett;
     84 }
     85 
     86 JNIEXPORT void JNICALL Java_org_tensorflow_Graph_importGraphDef(
     87     JNIEnv* env, jclass clazz, jlong handle, jbyteArray graph_def,
     88     jstring prefix) {
     89   TF_Graph* g = requireHandle(env, handle);
     90   if (g == nullptr) return;
     91 
     92   TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
     93 
     94   jboolean is_copy;
     95   const char* cprefix = env->GetStringUTFChars(prefix, &is_copy);
     96   TF_ImportGraphDefOptionsSetPrefix(opts, cprefix);
     97   env->ReleaseStringUTFChars(prefix, cprefix);
     98 
     99   static_assert(sizeof(jbyte) == 1, "unexpected size of the jbyte type");
    100   jbyte* bytes = env->GetByteArrayElements(graph_def, &is_copy);
    101   TF_Buffer* buf =
    102       TF_NewBufferFromString(bytes, env->GetArrayLength(graph_def));
    103   TF_Status* status = TF_NewStatus();
    104 
    105   TF_GraphImportGraphDef(g, buf, opts, status);
    106   throwExceptionIfNotOK(env, status);
    107   // Continue cleaning up resources even if an exception was thrown.
    108 
    109   TF_DeleteStatus(status);
    110   TF_DeleteBuffer(buf);
    111   env->ReleaseByteArrayElements(graph_def, bytes, JNI_ABORT);
    112 
    113   TF_DeleteImportGraphDefOptions(opts);
    114 }
    115 
    116 JNIEXPORT jbyteArray JNICALL
    117 Java_org_tensorflow_Graph_toGraphDef(JNIEnv* env, jclass clazz, jlong handle) {
    118   jbyteArray ret = nullptr;
    119   TF_Graph* g = requireHandle(env, handle);
    120   if (g == nullptr) return ret;
    121 
    122   TF_Buffer* buf = TF_NewBuffer();
    123   TF_Status* status = TF_NewStatus();
    124   TF_GraphToGraphDef(g, buf, status);
    125   if (throwExceptionIfNotOK(env, status)) {
    126     // sizeof(jsize) is less than sizeof(size_t) on some platforms.
    127     if (buf->length > std::numeric_limits<jint>::max()) {
    128       throwException(env, kIndexOutOfBoundsException,
    129                      "GraphDef is too large to serialize into a byte[] array");
    130     } else {
    131       static_assert(sizeof(jbyte) == 1, "unexpected size of the jbyte type");
    132       jint ret_len = static_cast<jint>(buf->length);
    133       ret = env->NewByteArray(ret_len);
    134       env->SetByteArrayRegion(ret, 0, ret_len,
    135                               static_cast<const jbyte*>(buf->data));
    136     }
    137   }
    138   TF_DeleteStatus(status);
    139   TF_DeleteBuffer(buf);
    140   return ret;
    141 }
    142 
    143 JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(
    144     JNIEnv* env, jclass clazz, jlong handle, jstring prefix,
    145     jlongArray y_handles, jintArray y_indices, jlongArray x_handles,
    146     jintArray x_indices, jlongArray dx_handles, jintArray dx_indices) {
    147   TF_Graph* g = requireHandle(env, handle);
    148   if (g == nullptr) return nullptr;
    149 
    150   const jint ny = env->GetArrayLength(y_handles);
    151   const jint nx = env->GetArrayLength(x_handles);
    152 
    153   std::unique_ptr<TF_Output[]> y(new TF_Output[ny]);
    154   std::unique_ptr<TF_Output[]> x(new TF_Output[nx]);
    155   std::unique_ptr<TF_Output[]> dx(nullptr);
    156   std::unique_ptr<TF_Output[]> dy(new TF_Output[nx]);
    157 
    158   resolveOutputs(env, "y", y_handles, y_indices, y.get(), ny);
    159   resolveOutputs(env, "x", x_handles, x_indices, x.get(), nx);
    160   if (dx_handles != nullptr) {
    161     if (env->GetArrayLength(dx_handles) != ny) {
    162       throwException(env, kIllegalArgumentException,
    163                      "expected %d, got %d dx handles", ny,
    164                      env->GetArrayLength(dx_handles));
    165     }
    166     dx.reset(new TF_Output[ny]);
    167     resolveOutputs(env, "dx", dx_handles, dx_indices, dx.get(), ny);
    168   }
    169   if (env->ExceptionCheck()) return nullptr;
    170 
    171   const char* cprefix = nullptr;
    172   if (prefix != nullptr) {
    173     cprefix = env->GetStringUTFChars(prefix, nullptr);
    174   }
    175   TF_Status* status = TF_NewStatus();
    176   TF_AddGradientsWithPrefix(g, cprefix, y.get(), ny, x.get(), nx, dx.get(),
    177                             status, dy.get());
    178   if (prefix != nullptr) {
    179     env->ReleaseStringUTFChars(prefix, cprefix);
    180   }
    181   if (!throwExceptionIfNotOK(env, status)) {
    182     TF_DeleteStatus(status);
    183     return nullptr;
    184   }
    185   TF_DeleteStatus(status);
    186 
    187   // returned array contains both op handles and output indices, in pair
    188   jlongArray dy_handles_and_indices = env->NewLongArray(nx << 1);
    189   jlong* dy_elems = env->GetLongArrayElements(dy_handles_and_indices, nullptr);
    190   for (int i = 0, j = nx; i < nx; ++i, ++j) {
    191     TF_Output dy_output = dy.get()[i];
    192     dy_elems[i] = reinterpret_cast<jlong>(dy_output.oper);
    193     dy_elems[j] = static_cast<jlong>(dy_output.index);
    194   }
    195   env->ReleaseLongArrayElements(dy_handles_and_indices, dy_elems, 0);
    196 
    197   return dy_handles_and_indices;
    198 }
    199 
    200 // helper function for while loop -- constructs conditional or body subgraph
    201 jlongArray buildSubgraph(JNIEnv* env, jclass clazz, jobject subgraph_builder,
    202                          TF_Graph* const subgraph,
    203                          const TF_Output* const inputs,
    204                          const TF_Output* const outputs, const int ninputs,
    205                          const int noutputs) {
    206   jmethodID build_subgraph_method_id = env->GetStaticMethodID(
    207       clazz, "buildSubgraph",
    208       "(Lorg/tensorflow/Graph$WhileSubgraphBuilder;J[J[I[J[I)[J");
    209   if (build_subgraph_method_id == 0) return nullptr;
    210 
    211   jlong subgraph_handle = reinterpret_cast<jlong>(subgraph);
    212 
    213   jlongArray input_handles = env->NewLongArray(ninputs);
    214   jintArray input_indices = env->NewIntArray(ninputs);
    215   jlongArray output_handles = env->NewLongArray(noutputs);
    216   jintArray output_indices = env->NewIntArray(noutputs);
    217 
    218   jlong* input_handles_elems =
    219       env->GetLongArrayElements(input_handles, nullptr);
    220   jint* input_indices_elems = env->GetIntArrayElements(input_indices, nullptr);
    221   jlong* output_handles_elems =
    222       env->GetLongArrayElements(output_handles, nullptr);
    223   jint* output_indices_elems =
    224       env->GetIntArrayElements(output_indices, nullptr);
    225 
    226   for (int i = 0; i < ninputs; ++i) {
    227     input_handles_elems[i] = reinterpret_cast<jlong>((inputs[i]).oper);
    228     input_indices_elems[i] = static_cast<jint>((inputs[i]).index);
    229   }
    230 
    231   for (int i = 0; i < noutputs; ++i) {
    232     output_handles_elems[i] = reinterpret_cast<jlong>((outputs[i]).oper);
    233     output_indices_elems[i] = static_cast<jint>((outputs[i]).index);
    234   }
    235 
    236   env->ReleaseLongArrayElements(input_handles, input_handles_elems, 0);
    237   env->ReleaseIntArrayElements(input_indices, input_indices_elems, 0);
    238   env->ReleaseLongArrayElements(output_handles, output_handles_elems, 0);
    239   env->ReleaseIntArrayElements(output_indices, output_indices_elems, 0);
    240 
    241   // call Java code to construct the subgraph
    242   jlongArray output_handles_and_indices =
    243       (jlongArray)env->CallStaticObjectMethod(
    244           clazz, build_subgraph_method_id, subgraph_builder, subgraph_handle,
    245           input_handles, input_indices, output_handles, output_indices);
    246 
    247   if (env->ExceptionOccurred()) {
    248     env->ExceptionDescribe();
    249     return nullptr;
    250   }
    251 
    252   // returned array contains both op handles and output indices, in pair
    253   return output_handles_and_indices;
    254 }
    255 
    256 JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_whileLoop(
    257     JNIEnv* env, jclass clazz, jlong handle, jlongArray input_handles,
    258     jintArray input_indices, jstring name, jobject cond_graph_builder,
    259     jobject body_graph_builder) {
    260   TF_Graph* g = requireHandle(env, handle);
    261   TF_Status* status = TF_NewStatus();
    262   if (g == nullptr) return nullptr;
    263 
    264   int ninputs = env->GetArrayLength(input_handles);
    265 
    266   std::unique_ptr<TF_Output[]> inputs(new TF_Output[ninputs]);
    267   resolveOutputs(env, "inputs", input_handles, input_indices, inputs.get(),
    268                  ninputs);
    269   if (env->ExceptionCheck()) return nullptr;
    270 
    271   // initialize while params
    272   TF_WhileParams params = TF_NewWhile(g, inputs.get(), ninputs, status);
    273   throwExceptionIfNotOK(env, status);
    274 
    275   // build conditional subgraph
    276   jlongArray cond_output_handles_and_indices =
    277       buildSubgraph(env, clazz, cond_graph_builder, params.cond_graph,
    278                     params.cond_inputs, &params.cond_output, params.ninputs, 1);
    279 
    280   // build body subgraph
    281   jlongArray body_output_handles_and_indices = buildSubgraph(
    282       env, clazz, body_graph_builder, params.body_graph, params.body_inputs,
    283       params.body_outputs, params.ninputs, params.ninputs);
    284 
    285   if (cond_output_handles_and_indices == nullptr ||
    286       body_output_handles_and_indices == nullptr)
    287     return nullptr;
    288 
    289   // set cond_output param to output of the conditional subgraph
    290   jlong* cond_output_elems =
    291       env->GetLongArrayElements(cond_output_handles_and_indices, nullptr);
    292   TF_Operation* cond_output_op =
    293       requireOperationHandle(env, cond_output_elems[0]);
    294   params.cond_output = {cond_output_op,
    295                         static_cast<jint>(cond_output_elems[1])};
    296   env->ReleaseLongArrayElements(cond_output_handles_and_indices,
    297                                 cond_output_elems, 0);
    298 
    299   // set body_outputs param to outputs of the body subgraph
    300   jlong* body_output_elems =
    301       env->GetLongArrayElements(body_output_handles_and_indices, nullptr);
    302   for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) {
    303     TF_Operation* body_output_op =
    304         requireOperationHandle(env, body_output_elems[i]);
    305     params.body_outputs[i] = {body_output_op,
    306                               static_cast<jint>(body_output_elems[j])};
    307   }
    308   env->ReleaseLongArrayElements(body_output_handles_and_indices,
    309                                 body_output_elems, 0);
    310 
    311   // set loop name param
    312   params.name = env->GetStringUTFChars(name, 0);
    313 
    314   // build the while loop, storing loop outputs in `outputs`
    315   std::unique_ptr<TF_Output[]> outputs(new TF_Output[ninputs]);
    316   TF_FinishWhile(&params, status, outputs.get());
    317 
    318   throwExceptionIfNotOK(env, status);
    319   TF_DeleteStatus(status);
    320 
    321   env->ReleaseStringUTFChars(name, params.name);
    322 
    323   // returned array contains both op handles and output indices, in pair
    324   jlongArray output_handles_and_indices = env->NewLongArray(ninputs * 2);
    325   jlong* output_elems =
    326       env->GetLongArrayElements(output_handles_and_indices, nullptr);
    327   for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) {
    328     TF_Output output = outputs.get()[i];
    329     output_elems[i] = reinterpret_cast<jlong>(output.oper);
    330     output_elems[j] = static_cast<jlong>(output.index);
    331   }
    332   env->ReleaseLongArrayElements(output_handles_and_indices, output_elems, 0);
    333 
    334   return output_handles_and_indices;
    335 }
    336