1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/java/src/main/native/graph_jni.h" 17 18 #include <limits> 19 #include <memory> 20 #include "tensorflow/c/c_api.h" 21 #include "tensorflow/java/src/main/native/exception_jni.h" 22 #include "tensorflow/java/src/main/native/utils_jni.h" 23 24 namespace { 25 template <class T> 26 T* requireHandleImpl(JNIEnv* env, jlong handle) { 27 static_assert(sizeof(jlong) >= sizeof(T*), 28 "Cannot package C object pointers as a Java long"); 29 if (handle == 0) { 30 throwException(env, kIllegalStateException, 31 "close() has been called on the Graph"); 32 return nullptr; 33 } 34 return reinterpret_cast<T*>(handle); 35 } 36 37 TF_Graph* requireHandle(JNIEnv* env, jlong handle) { 38 return requireHandleImpl<TF_Graph>(env, handle); 39 } 40 41 TF_Operation* requireOperationHandle(JNIEnv* env, jlong handle) { 42 return requireHandleImpl<TF_Operation>(env, handle); 43 } 44 } // namespace 45 46 JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_allocate(JNIEnv*, jclass) { 47 return reinterpret_cast<jlong>(TF_NewGraph()); 48 } 49 50 JNIEXPORT void JNICALL Java_org_tensorflow_Graph_delete(JNIEnv*, jclass, 51 jlong handle) { 52 if (handle == 0) return; 53 TF_DeleteGraph(reinterpret_cast<TF_Graph*>(handle)); 54 } 55 56 JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_operation(JNIEnv* env, 57 jclass clazz, 58 jlong handle, 59 jstring name) { 60 TF_Graph* g = requireHandle(env, handle); 61 if (g == nullptr) return 0; 62 const char* cname = env->GetStringUTFChars(name, nullptr); 63 TF_Operation* op = TF_GraphOperationByName(g, cname); 64 env->ReleaseStringUTFChars(name, cname); 65 return reinterpret_cast<jlong>(op); 66 } 67 68 JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_nextOperation( 69 JNIEnv* env, jclass clazz, jlong handle, jint position) { 70 TF_Graph* g = requireHandle(env, handle); 71 if (g == nullptr) return nullptr; 72 73 size_t pos = static_cast<size_t>(position); 74 TF_Operation* operation = TF_GraphNextOperation(g, &pos); 75 if (operation == nullptr) return nullptr; 76 77 jlong handle_and_position[2]; 78 handle_and_position[0] = reinterpret_cast<jlong>(operation); 79 handle_and_position[1] = static_cast<jlong>(pos); 80 81 jlongArray rhett = env->NewLongArray(2); 82 env->SetLongArrayRegion(rhett, 0, 2, handle_and_position); 83 return rhett; 84 } 85 86 JNIEXPORT void JNICALL Java_org_tensorflow_Graph_importGraphDef( 87 JNIEnv* env, jclass clazz, jlong handle, jbyteArray graph_def, 88 jstring prefix) { 89 TF_Graph* g = requireHandle(env, handle); 90 if (g == nullptr) return; 91 92 TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions(); 93 94 jboolean is_copy; 95 const char* cprefix = env->GetStringUTFChars(prefix, &is_copy); 96 TF_ImportGraphDefOptionsSetPrefix(opts, cprefix); 97 env->ReleaseStringUTFChars(prefix, cprefix); 98 99 static_assert(sizeof(jbyte) == 1, "unexpected size of the jbyte type"); 100 jbyte* bytes = env->GetByteArrayElements(graph_def, &is_copy); 101 TF_Buffer* buf = 102 TF_NewBufferFromString(bytes, env->GetArrayLength(graph_def)); 103 TF_Status* status = TF_NewStatus(); 104 105 TF_GraphImportGraphDef(g, buf, opts, status); 106 throwExceptionIfNotOK(env, status); 107 // Continue cleaning up resources even if an exception was thrown. 108 109 TF_DeleteStatus(status); 110 TF_DeleteBuffer(buf); 111 env->ReleaseByteArrayElements(graph_def, bytes, JNI_ABORT); 112 113 TF_DeleteImportGraphDefOptions(opts); 114 } 115 116 JNIEXPORT jbyteArray JNICALL 117 Java_org_tensorflow_Graph_toGraphDef(JNIEnv* env, jclass clazz, jlong handle) { 118 jbyteArray ret = nullptr; 119 TF_Graph* g = requireHandle(env, handle); 120 if (g == nullptr) return ret; 121 122 TF_Buffer* buf = TF_NewBuffer(); 123 TF_Status* status = TF_NewStatus(); 124 TF_GraphToGraphDef(g, buf, status); 125 if (throwExceptionIfNotOK(env, status)) { 126 // sizeof(jsize) is less than sizeof(size_t) on some platforms. 127 if (buf->length > std::numeric_limits<jint>::max()) { 128 throwException(env, kIndexOutOfBoundsException, 129 "GraphDef is too large to serialize into a byte[] array"); 130 } else { 131 static_assert(sizeof(jbyte) == 1, "unexpected size of the jbyte type"); 132 jint ret_len = static_cast<jint>(buf->length); 133 ret = env->NewByteArray(ret_len); 134 env->SetByteArrayRegion(ret, 0, ret_len, 135 static_cast<const jbyte*>(buf->data)); 136 } 137 } 138 TF_DeleteStatus(status); 139 TF_DeleteBuffer(buf); 140 return ret; 141 } 142 143 JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients( 144 JNIEnv* env, jclass clazz, jlong handle, jstring prefix, 145 jlongArray y_handles, jintArray y_indices, jlongArray x_handles, 146 jintArray x_indices, jlongArray dx_handles, jintArray dx_indices) { 147 TF_Graph* g = requireHandle(env, handle); 148 if (g == nullptr) return nullptr; 149 150 const jint ny = env->GetArrayLength(y_handles); 151 const jint nx = env->GetArrayLength(x_handles); 152 153 std::unique_ptr<TF_Output[]> y(new TF_Output[ny]); 154 std::unique_ptr<TF_Output[]> x(new TF_Output[nx]); 155 std::unique_ptr<TF_Output[]> dx(nullptr); 156 std::unique_ptr<TF_Output[]> dy(new TF_Output[nx]); 157 158 resolveOutputs(env, "y", y_handles, y_indices, y.get(), ny); 159 resolveOutputs(env, "x", x_handles, x_indices, x.get(), nx); 160 if (dx_handles != nullptr) { 161 if (env->GetArrayLength(dx_handles) != ny) { 162 throwException(env, kIllegalArgumentException, 163 "expected %d, got %d dx handles", ny, 164 env->GetArrayLength(dx_handles)); 165 } 166 dx.reset(new TF_Output[ny]); 167 resolveOutputs(env, "dx", dx_handles, dx_indices, dx.get(), ny); 168 } 169 if (env->ExceptionCheck()) return nullptr; 170 171 const char* cprefix = nullptr; 172 if (prefix != nullptr) { 173 cprefix = env->GetStringUTFChars(prefix, nullptr); 174 } 175 TF_Status* status = TF_NewStatus(); 176 TF_AddGradientsWithPrefix(g, cprefix, y.get(), ny, x.get(), nx, dx.get(), 177 status, dy.get()); 178 if (prefix != nullptr) { 179 env->ReleaseStringUTFChars(prefix, cprefix); 180 } 181 if (!throwExceptionIfNotOK(env, status)) { 182 TF_DeleteStatus(status); 183 return nullptr; 184 } 185 TF_DeleteStatus(status); 186 187 // returned array contains both op handles and output indices, in pair 188 jlongArray dy_handles_and_indices = env->NewLongArray(nx << 1); 189 jlong* dy_elems = env->GetLongArrayElements(dy_handles_and_indices, nullptr); 190 for (int i = 0, j = nx; i < nx; ++i, ++j) { 191 TF_Output dy_output = dy.get()[i]; 192 dy_elems[i] = reinterpret_cast<jlong>(dy_output.oper); 193 dy_elems[j] = static_cast<jlong>(dy_output.index); 194 } 195 env->ReleaseLongArrayElements(dy_handles_and_indices, dy_elems, 0); 196 197 return dy_handles_and_indices; 198 } 199 200 // helper function for while loop -- constructs conditional or body subgraph 201 jlongArray buildSubgraph(JNIEnv* env, jclass clazz, jobject subgraph_builder, 202 TF_Graph* const subgraph, 203 const TF_Output* const inputs, 204 const TF_Output* const outputs, const int ninputs, 205 const int noutputs) { 206 jmethodID build_subgraph_method_id = env->GetStaticMethodID( 207 clazz, "buildSubgraph", 208 "(Lorg/tensorflow/Graph$WhileSubgraphBuilder;J[J[I[J[I)[J"); 209 if (build_subgraph_method_id == 0) return nullptr; 210 211 jlong subgraph_handle = reinterpret_cast<jlong>(subgraph); 212 213 jlongArray input_handles = env->NewLongArray(ninputs); 214 jintArray input_indices = env->NewIntArray(ninputs); 215 jlongArray output_handles = env->NewLongArray(noutputs); 216 jintArray output_indices = env->NewIntArray(noutputs); 217 218 jlong* input_handles_elems = 219 env->GetLongArrayElements(input_handles, nullptr); 220 jint* input_indices_elems = env->GetIntArrayElements(input_indices, nullptr); 221 jlong* output_handles_elems = 222 env->GetLongArrayElements(output_handles, nullptr); 223 jint* output_indices_elems = 224 env->GetIntArrayElements(output_indices, nullptr); 225 226 for (int i = 0; i < ninputs; ++i) { 227 input_handles_elems[i] = reinterpret_cast<jlong>((inputs[i]).oper); 228 input_indices_elems[i] = static_cast<jint>((inputs[i]).index); 229 } 230 231 for (int i = 0; i < noutputs; ++i) { 232 output_handles_elems[i] = reinterpret_cast<jlong>((outputs[i]).oper); 233 output_indices_elems[i] = static_cast<jint>((outputs[i]).index); 234 } 235 236 env->ReleaseLongArrayElements(input_handles, input_handles_elems, 0); 237 env->ReleaseIntArrayElements(input_indices, input_indices_elems, 0); 238 env->ReleaseLongArrayElements(output_handles, output_handles_elems, 0); 239 env->ReleaseIntArrayElements(output_indices, output_indices_elems, 0); 240 241 // call Java code to construct the subgraph 242 jlongArray output_handles_and_indices = 243 (jlongArray)env->CallStaticObjectMethod( 244 clazz, build_subgraph_method_id, subgraph_builder, subgraph_handle, 245 input_handles, input_indices, output_handles, output_indices); 246 247 if (env->ExceptionOccurred()) { 248 env->ExceptionDescribe(); 249 return nullptr; 250 } 251 252 // returned array contains both op handles and output indices, in pair 253 return output_handles_and_indices; 254 } 255 256 JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_whileLoop( 257 JNIEnv* env, jclass clazz, jlong handle, jlongArray input_handles, 258 jintArray input_indices, jstring name, jobject cond_graph_builder, 259 jobject body_graph_builder) { 260 TF_Graph* g = requireHandle(env, handle); 261 TF_Status* status = TF_NewStatus(); 262 if (g == nullptr) return nullptr; 263 264 int ninputs = env->GetArrayLength(input_handles); 265 266 std::unique_ptr<TF_Output[]> inputs(new TF_Output[ninputs]); 267 resolveOutputs(env, "inputs", input_handles, input_indices, inputs.get(), 268 ninputs); 269 if (env->ExceptionCheck()) return nullptr; 270 271 // initialize while params 272 TF_WhileParams params = TF_NewWhile(g, inputs.get(), ninputs, status); 273 throwExceptionIfNotOK(env, status); 274 275 // build conditional subgraph 276 jlongArray cond_output_handles_and_indices = 277 buildSubgraph(env, clazz, cond_graph_builder, params.cond_graph, 278 params.cond_inputs, ¶ms.cond_output, params.ninputs, 1); 279 280 // build body subgraph 281 jlongArray body_output_handles_and_indices = buildSubgraph( 282 env, clazz, body_graph_builder, params.body_graph, params.body_inputs, 283 params.body_outputs, params.ninputs, params.ninputs); 284 285 if (cond_output_handles_and_indices == nullptr || 286 body_output_handles_and_indices == nullptr) 287 return nullptr; 288 289 // set cond_output param to output of the conditional subgraph 290 jlong* cond_output_elems = 291 env->GetLongArrayElements(cond_output_handles_and_indices, nullptr); 292 TF_Operation* cond_output_op = 293 requireOperationHandle(env, cond_output_elems[0]); 294 params.cond_output = {cond_output_op, 295 static_cast<jint>(cond_output_elems[1])}; 296 env->ReleaseLongArrayElements(cond_output_handles_and_indices, 297 cond_output_elems, 0); 298 299 // set body_outputs param to outputs of the body subgraph 300 jlong* body_output_elems = 301 env->GetLongArrayElements(body_output_handles_and_indices, nullptr); 302 for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) { 303 TF_Operation* body_output_op = 304 requireOperationHandle(env, body_output_elems[i]); 305 params.body_outputs[i] = {body_output_op, 306 static_cast<jint>(body_output_elems[j])}; 307 } 308 env->ReleaseLongArrayElements(body_output_handles_and_indices, 309 body_output_elems, 0); 310 311 // set loop name param 312 params.name = env->GetStringUTFChars(name, 0); 313 314 // build the while loop, storing loop outputs in `outputs` 315 std::unique_ptr<TF_Output[]> outputs(new TF_Output[ninputs]); 316 TF_FinishWhile(¶ms, status, outputs.get()); 317 318 throwExceptionIfNotOK(env, status); 319 TF_DeleteStatus(status); 320 321 env->ReleaseStringUTFChars(name, params.name); 322 323 // returned array contains both op handles and output indices, in pair 324 jlongArray output_handles_and_indices = env->NewLongArray(ninputs * 2); 325 jlong* output_elems = 326 env->GetLongArrayElements(output_handles_and_indices, nullptr); 327 for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) { 328 TF_Output output = outputs.get()[i]; 329 output_elems[i] = reinterpret_cast<jlong>(output.oper); 330 output_elems[j] = static_cast<jlong>(output.index); 331 } 332 env->ReleaseLongArrayElements(output_handles_and_indices, output_elems, 0); 333 334 return output_handles_and_indices; 335 } 336