Home | History | Annotate | Download | only in native
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/java/src/main/native/tensor_jni.h"
     17 
     18 #include <assert.h>
     19 #include <stdlib.h>
     20 #include <string.h>
     21 #include <algorithm>
     22 #include <memory>
     23 
     24 #include "tensorflow/c/c_api.h"
     25 #include "tensorflow/java/src/main/native/exception_jni.h"
     26 
     27 namespace {
     28 
     29 TF_Tensor* requireHandle(JNIEnv* env, jlong handle) {
     30   if (handle == 0) {
     31     throwException(env, kNullPointerException,
     32                    "close() was called on the Tensor");
     33     return nullptr;
     34   }
     35   return reinterpret_cast<TF_Tensor*>(handle);
     36 }
     37 
     38 size_t elemByteSize(TF_DataType dtype) {
     39   // The code in this file makes the assumption that the
     40   // TensorFlow TF_DataTypes and the Java primitive types
     41   // have the same byte sizes. Validate that:
     42   switch (dtype) {
     43     case TF_BOOL:
     44     case TF_UINT8:
     45       static_assert(sizeof(jboolean) == 1,
     46                     "Java boolean not compatible with TF_BOOL");
     47       static_assert(sizeof(jbyte) == 1,
     48                     "Java byte not compatible with TF_UINT8");
     49       return 1;
     50     case TF_FLOAT:
     51     case TF_INT32:
     52       static_assert(sizeof(jfloat) == 4,
     53                     "Java float not compatible with TF_FLOAT");
     54       static_assert(sizeof(jint) == 4, "Java int not compatible with TF_INT32");
     55       return 4;
     56     case TF_DOUBLE:
     57     case TF_INT64:
     58       static_assert(sizeof(jdouble) == 8,
     59                     "Java double not compatible with TF_DOUBLE");
     60       static_assert(sizeof(jlong) == 8,
     61                     "Java long not compatible with TF_INT64");
     62       return 8;
     63     default:
     64       return 0;
     65   }
     66 }
     67 
     68 // Write a Java scalar object (java.lang.Integer etc.) to a TF_Tensor.
     69 void writeScalar(JNIEnv* env, jobject src, TF_DataType dtype, void* dst,
     70                  size_t dst_size) {
     71   size_t sz = elemByteSize(dtype);
     72   if (sz != dst_size) {
     73     throwException(
     74         env, kIllegalStateException,
     75         "scalar (%d bytes) not compatible with allocated tensor (%d bytes)", sz,
     76         dst_size);
     77     return;
     78   }
     79   switch (dtype) {
     80 // env->FindClass and env->GetMethodID are expensive and JNI best practices
     81 // suggest that they should be cached. However, until the creation of scalar
     82 // valued tensors seems to become a noticeable fraction of program execution,
     83 // ignore that cost.
     84 #define CASE(dtype, jtype, method_name, method_signature, call_type)           \
     85   case dtype: {                                                                \
     86     jclass clazz = env->FindClass("java/lang/Number");                         \
     87     jmethodID method = env->GetMethodID(clazz, method_name, method_signature); \
     88     jtype v = env->Call##call_type##Method(src, method);                       \
     89     memcpy(dst, &v, sz);                                                       \
     90     return;                                                                    \
     91   }
     92     CASE(TF_FLOAT, jfloat, "floatValue", "()F", Float);
     93     CASE(TF_DOUBLE, jdouble, "doubleValue", "()D", Double);
     94     CASE(TF_INT32, jint, "intValue", "()I", Int);
     95     CASE(TF_INT64, jlong, "longValue", "()J", Long);
     96     CASE(TF_UINT8, jbyte, "byteValue", "()B", Byte);
     97 #undef CASE
     98     case TF_BOOL: {
     99       jclass clazz = env->FindClass("java/lang/Boolean");
    100       jmethodID method = env->GetMethodID(clazz, "booleanValue", "()Z");
    101       jboolean v = env->CallBooleanMethod(src, method);
    102       *(static_cast<unsigned char*>(dst)) = v ? 1 : 0;
    103       return;
    104     }
    105     default:
    106       throwException(env, kIllegalStateException, "invalid DataType(%d)",
    107                      dtype);
    108       return;
    109   }
    110 }
    111 
    112 // Copy a 1-D array of Java primitive types to the tensor buffer dst.
    113 // Returns the number of bytes written to dst.
    114 size_t write1DArray(JNIEnv* env, jarray array, TF_DataType dtype, void* dst,
    115                     size_t dst_size) {
    116   const int nelems = env->GetArrayLength(array);
    117   jboolean is_copy;
    118   switch (dtype) {
    119 #define CASE(dtype, jtype, get_type)                                   \
    120   case dtype: {                                                        \
    121     jtype##Array a = static_cast<jtype##Array>(array);                 \
    122     jtype* values = env->Get##get_type##ArrayElements(a, &is_copy);    \
    123     size_t to_copy = nelems * elemByteSize(dtype);                     \
    124     if (to_copy > dst_size) {                                          \
    125       throwException(                                                  \
    126           env, kIllegalStateException,                                 \
    127           "cannot write Java array of %d bytes to Tensor of %d bytes", \
    128           to_copy, dst_size);                                          \
    129       to_copy = 0;                                                     \
    130     } else {                                                           \
    131       memcpy(dst, values, to_copy);                                    \
    132     }                                                                  \
    133     env->Release##get_type##ArrayElements(a, values, JNI_ABORT);       \
    134     return to_copy;                                                    \
    135   }
    136     CASE(TF_FLOAT, jfloat, Float);
    137     CASE(TF_DOUBLE, jdouble, Double);
    138     CASE(TF_INT32, jint, Int);
    139     CASE(TF_INT64, jlong, Long);
    140     CASE(TF_BOOL, jboolean, Boolean);
    141     CASE(TF_UINT8, jbyte, Byte);
    142 #undef CASE
    143     default:
    144       throwException(env, kIllegalStateException, "invalid DataType(%d)",
    145                      dtype);
    146       return 0;
    147   }
    148 }
    149 
    150 // Copy the elements of a 1-D array from the tensor buffer src to a 1-D array of
    151 // Java primitive types. Returns the number of bytes read from src.
    152 size_t read1DArray(JNIEnv* env, TF_DataType dtype, const void* src,
    153                    size_t src_size, jarray dst) {
    154   const int len = env->GetArrayLength(dst);
    155   const size_t sz = len * elemByteSize(dtype);
    156   if (sz > src_size) {
    157     throwException(
    158         env, kIllegalStateException,
    159         "cannot fill a Java array of %d bytes with a Tensor of %d bytes", sz,
    160         src_size);
    161     return 0;
    162   }
    163   switch (dtype) {
    164 #define CASE(dtype, jtype, primitive_type)                                 \
    165   case dtype: {                                                            \
    166     jtype##Array arr = static_cast<jtype##Array>(dst);                     \
    167     env->Set##primitive_type##ArrayRegion(arr, 0, len,                     \
    168                                           static_cast<const jtype*>(src)); \
    169     return sz;                                                             \
    170   }
    171     CASE(TF_FLOAT, jfloat, Float);
    172     CASE(TF_DOUBLE, jdouble, Double);
    173     CASE(TF_INT32, jint, Int);
    174     CASE(TF_INT64, jlong, Long);
    175     CASE(TF_BOOL, jboolean, Boolean);
    176     CASE(TF_UINT8, jbyte, Byte);
    177 #undef CASE
    178     default:
    179       throwException(env, kIllegalStateException, "invalid DataType(%d)",
    180                      dtype);
    181   }
    182   return 0;
    183 }
    184 
    185 size_t writeNDArray(JNIEnv* env, jarray src, TF_DataType dtype, int dims_left,
    186                     char* dst, size_t dst_size) {
    187   if (dims_left == 1) {
    188     return write1DArray(env, src, dtype, dst, dst_size);
    189   } else {
    190     jobjectArray ndarray = static_cast<jobjectArray>(src);
    191     int len = env->GetArrayLength(ndarray);
    192     size_t sz = 0;
    193     for (int i = 0; i < len; ++i) {
    194       jarray row = static_cast<jarray>(env->GetObjectArrayElement(ndarray, i));
    195       sz +=
    196           writeNDArray(env, row, dtype, dims_left - 1, dst + sz, dst_size - sz);
    197       env->DeleteLocalRef(row);
    198       if (env->ExceptionCheck()) return sz;
    199     }
    200     return sz;
    201   }
    202 }
    203 
    204 size_t readNDArray(JNIEnv* env, TF_DataType dtype, const char* src,
    205                    size_t src_size, int dims_left, jarray dst) {
    206   if (dims_left == 1) {
    207     return read1DArray(env, dtype, src, src_size, dst);
    208   } else {
    209     jobjectArray ndarray = static_cast<jobjectArray>(dst);
    210     int len = env->GetArrayLength(ndarray);
    211     size_t sz = 0;
    212     for (int i = 0; i < len; ++i) {
    213       jarray row = static_cast<jarray>(env->GetObjectArrayElement(ndarray, i));
    214       sz +=
    215           readNDArray(env, dtype, src + sz, src_size - sz, dims_left - 1, row);
    216       env->DeleteLocalRef(row);
    217       if (env->ExceptionCheck()) return sz;
    218     }
    219     return sz;
    220   }
    221 }
    222 
    223 jbyteArray TF_StringDecodeTojbyteArray(JNIEnv* env, const char* src,
    224                                        size_t src_len, TF_Status* status) {
    225   const char* dst = nullptr;
    226   size_t dst_len = 0;
    227   TF_StringDecode(src, src_len, &dst, &dst_len, status);
    228   if (TF_GetCode(status) != TF_OK) {
    229     return nullptr;
    230   }
    231   jbyteArray ret = env->NewByteArray(dst_len);
    232   jbyte* cpy = env->GetByteArrayElements(ret, nullptr);
    233   memcpy(cpy, dst, dst_len);
    234   env->ReleaseByteArrayElements(ret, cpy, 0);
    235   return ret;
    236 }
    237 
    238 class StringTensorWriter {
    239  public:
    240   StringTensorWriter(TF_Tensor* t, int num_elements)
    241       : offset_(0),
    242         poffsets_(static_cast<char*>(TF_TensorData(t))),
    243         pdata_(poffsets_ + 8 * num_elements),
    244         plimit_(poffsets_ + TF_TensorByteSize(t)) {}
    245 
    246   void Add(const char* src, size_t len, TF_Status* status) {
    247     if (TF_GetCode(status) != TF_OK) return;
    248     if (plimit_ - poffsets_ < sizeof(offset_)) {
    249       TF_SetStatus(status, TF_OUT_OF_RANGE,
    250                    "TF_STRING tensor encoding ran out of space for offsets, "
    251                    "this is likely a bug, please file an issue at "
    252                    "https://github.com/tensorflow/tensorflow/issues/new");
    253       return;
    254     }
    255     memcpy(poffsets_, &offset_, sizeof(offset_));
    256     size_t written =
    257         TF_StringEncode(src, len, pdata_, (plimit_ - pdata_), status);
    258     offset_ += written;
    259     poffsets_ += 8;
    260     pdata_ += written;
    261   }
    262 
    263  private:
    264   uint64_t offset_;
    265   char* poffsets_;
    266   char* pdata_;
    267   const char* plimit_;
    268 };
    269 
    270 class StringTensorReader {
    271  public:
    272   StringTensorReader(const TF_Tensor* t, int num_elements)
    273       : index_(0),
    274         offsets_(static_cast<const char*>(TF_TensorData(t))),
    275         data_(offsets_ + 8 * num_elements),
    276         limit_(offsets_ + TF_TensorByteSize(t)) {}
    277 
    278   jbyteArray Next(JNIEnv* env, TF_Status* status) {
    279     if (TF_GetCode(status) != TF_OK) return nullptr;
    280     uint64_t offset = 0;
    281     const char* poffset = offsets_ + sizeof(offset) * index_;
    282     if (poffset >= limit_) {
    283       TF_SetStatus(
    284           status, TF_INTERNAL,
    285           "Invalid TF_STRING tensor, offsets table seems to be too small");
    286       return nullptr;
    287     }
    288     memcpy(&offset, poffset, sizeof(offset));
    289     const char* pdata = data_ + offset;
    290     if (pdata >= limit_) {
    291       TF_SetStatus(status, TF_INTERNAL,
    292                    "Invalid TF_STRING tensor, invalid entry in offset table");
    293       return nullptr;
    294     }
    295     ++index_;
    296     return TF_StringDecodeTojbyteArray(env, pdata, (limit_ - pdata), status);
    297   }
    298 
    299  private:
    300   int index_;
    301   const char* offsets_;
    302   const char* data_;
    303   const char* limit_;
    304 };
    305 
    306 void readNDStringArray(JNIEnv* env, StringTensorReader* reader, int dims_left,
    307                        jobjectArray dst, TF_Status* status) {
    308   jsize len = env->GetArrayLength(dst);
    309   if (dims_left == 1) {
    310     for (jsize i = 0; i < len; ++i) {
    311       jbyteArray elem = reader->Next(env, status);
    312       if (TF_GetCode(status) != TF_OK) return;
    313       env->SetObjectArrayElement(dst, i, elem);
    314     }
    315     return;
    316   }
    317   for (jsize i = 0; i < len; ++i) {
    318     jobjectArray arr =
    319         static_cast<jobjectArray>(env->GetObjectArrayElement(dst, i));
    320     readNDStringArray(env, reader, dims_left - 1, arr, status);
    321     if (TF_GetCode(status) != TF_OK) return;
    322   }
    323 }
    324 }  // namespace
    325 
    326 JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocate(JNIEnv* env,
    327                                                             jclass clazz,
    328                                                             jint dtype,
    329                                                             jlongArray shape,
    330                                                             jlong sizeInBytes) {
    331   int num_dims = static_cast<int>(env->GetArrayLength(shape));
    332   jlong* dims = nullptr;
    333   if (num_dims > 0) {
    334     jboolean is_copy;
    335     dims = env->GetLongArrayElements(shape, &is_copy);
    336   }
    337   static_assert(sizeof(jlong) == sizeof(int64_t),
    338                 "Java long is not compatible with the TensorFlow C API");
    339   // On some platforms "jlong" is a "long" while "int64_t" is a "long long".
    340   //
    341   // Thus, static_cast<int64_t*>(dims) will trigger a compiler error:
    342   // static_cast from 'jlong *' (aka 'long *') to 'int64_t *' (aka 'long long
    343   // *') is not allowed
    344   //
    345   // Since this array is typically very small, use the guaranteed safe scheme of
    346   // creating a copy.
    347   int64_t* dims_copy = new int64_t[num_dims];
    348   for (int i = 0; i < num_dims; ++i) {
    349     dims_copy[i] = static_cast<int64_t>(dims[i]);
    350   }
    351   TF_Tensor* t = TF_AllocateTensor(static_cast<TF_DataType>(dtype), dims_copy,
    352                                    num_dims, static_cast<size_t>(sizeInBytes));
    353   delete[] dims_copy;
    354   if (dims != nullptr) {
    355     env->ReleaseLongArrayElements(shape, dims, JNI_ABORT);
    356   }
    357   if (t == nullptr) {
    358     throwException(env, kNullPointerException,
    359                    "unable to allocate memory for the Tensor");
    360     return 0;
    361   }
    362   return reinterpret_cast<jlong>(t);
    363 }
    364 
    365 JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateScalarBytes(
    366     JNIEnv* env, jclass clazz, jbyteArray value) {
    367   // TF_STRING tensors are encoded with a table of 8-byte offsets followed by
    368   // TF_StringEncode-encoded bytes.
    369   size_t src_len = static_cast<int>(env->GetArrayLength(value));
    370   size_t dst_len = TF_StringEncodedSize(src_len);
    371   TF_Tensor* t = TF_AllocateTensor(TF_STRING, nullptr, 0, 8 + dst_len);
    372   char* dst = static_cast<char*>(TF_TensorData(t));
    373   memset(dst, 0, 8);  // The offset table
    374 
    375   TF_Status* status = TF_NewStatus();
    376   jbyte* jsrc = env->GetByteArrayElements(value, nullptr);
    377   // jsrc is an unsigned byte*, TF_StringEncode requires a char*.
    378   // reinterpret_cast<> for this conversion should be safe.
    379   TF_StringEncode(reinterpret_cast<const char*>(jsrc), src_len, dst + 8,
    380                   dst_len, status);
    381   env->ReleaseByteArrayElements(value, jsrc, JNI_ABORT);
    382   if (!throwExceptionIfNotOK(env, status)) {
    383     TF_DeleteStatus(status);
    384     return 0;
    385   }
    386   TF_DeleteStatus(status);
    387   return reinterpret_cast<jlong>(t);
    388 }
    389 
    390 namespace {
    391 size_t nonScalarTF_STRINGTensorSize(JNIEnv* env, jarray value, int num_dims) {
    392   if (num_dims == 0) {
    393     // This is the last dimension, i.e., value should correspond to a jbyteArray
    394     // encoding the string.
    395     return TF_StringEncodedSize(
    396         static_cast<size_t>(env->GetArrayLength(value)));
    397   }
    398   jsize len = env->GetArrayLength(value);
    399   size_t ret = 0;
    400   for (jsize i = 0; i < len; ++i) {
    401     jarray elem = static_cast<jarray>(
    402         env->GetObjectArrayElement(static_cast<jobjectArray>(value), i));
    403     ret += nonScalarTF_STRINGTensorSize(env, elem, num_dims - 1);
    404   }
    405   return ret;
    406 }
    407 
    408 void fillNonScalarTF_STRINGTensorData(JNIEnv* env, jarray value, int num_dims,
    409                                       StringTensorWriter* writer,
    410                                       TF_Status* status) {
    411   if (num_dims == 0) {
    412     jbyte* jsrc =
    413         env->GetByteArrayElements(static_cast<jbyteArray>(value), nullptr);
    414     writer->Add(reinterpret_cast<const char*>(jsrc), env->GetArrayLength(value),
    415                 status);
    416     env->ReleaseByteArrayElements(static_cast<jbyteArray>(value), jsrc,
    417                                   JNI_ABORT);
    418     return;
    419   }
    420   jsize len = env->GetArrayLength(value);
    421   for (jsize i = 0; i < len; ++i) {
    422     jarray elem = static_cast<jarray>(
    423         env->GetObjectArrayElement(static_cast<jobjectArray>(value), i));
    424     if (TF_GetCode(status) != TF_OK) return;
    425     fillNonScalarTF_STRINGTensorData(env, elem, num_dims - 1, writer, status);
    426   }
    427 }
    428 }  // namespace
    429 
    430 JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateNonScalarBytes(
    431     JNIEnv* env, jclass clazz, jlongArray shape, jobjectArray value) {
    432   // TF_STRING tensors are encoded with a table of 8-byte offsets following by
    433   // TF_StringEncode-encoded bytes.
    434   const int num_dims = static_cast<int>(env->GetArrayLength(shape));
    435   int64_t* dims = new int64_t[num_dims];
    436   int64_t num_elements = 1;
    437   {
    438     jlong* jdims = env->GetLongArrayElements(shape, nullptr);
    439     for (int i = 0; i < num_dims; ++i) {
    440       dims[i] = static_cast<int64_t>(jdims[i]);
    441       num_elements *= dims[i];
    442     }
    443     env->ReleaseLongArrayElements(shape, jdims, JNI_ABORT);
    444   }
    445   const size_t encoded_size =
    446       nonScalarTF_STRINGTensorSize(env, value, num_dims);
    447   TF_Tensor* t = TF_AllocateTensor(TF_STRING, dims, num_dims,
    448                                    8 * num_elements + encoded_size);
    449   if (t == nullptr) {
    450     delete[] dims;
    451     throwException(env, kNullPointerException,
    452                    "unable to allocate memory for the Tensor");
    453     return 0;
    454   }
    455   TF_Status* status = TF_NewStatus();
    456   StringTensorWriter writer(t, num_elements);
    457   fillNonScalarTF_STRINGTensorData(env, value, num_dims, &writer, status);
    458   delete[] dims;
    459   jlong ret = 0;
    460   if (!throwExceptionIfNotOK(env, status)) {
    461     TF_DeleteTensor(t);
    462   } else {
    463     ret = reinterpret_cast<jlong>(t);
    464   }
    465   TF_DeleteStatus(status);
    466   return ret;
    467 }
    468 
    469 JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_delete(JNIEnv* env,
    470                                                          jclass clazz,
    471                                                          jlong handle) {
    472   if (handle == 0) return;
    473   TF_DeleteTensor(reinterpret_cast<TF_Tensor*>(handle));
    474 }
    475 
    476 JNIEXPORT jobject JNICALL Java_org_tensorflow_Tensor_buffer(JNIEnv* env,
    477                                                             jclass clazz,
    478                                                             jlong handle) {
    479   TF_Tensor* t = requireHandle(env, handle);
    480   if (t == nullptr) return nullptr;
    481   void* data = TF_TensorData(t);
    482   const size_t sz = TF_TensorByteSize(t);
    483 
    484   return env->NewDirectByteBuffer(data, static_cast<jlong>(sz));
    485 }
    486 
    487 JNIEXPORT jint JNICALL Java_org_tensorflow_Tensor_dtype(JNIEnv* env,
    488                                                         jclass clazz,
    489                                                         jlong handle) {
    490   static_assert(sizeof(jint) >= sizeof(TF_DataType),
    491                 "TF_DataType in C cannot be represented as an int in Java");
    492   TF_Tensor* t = requireHandle(env, handle);
    493   if (t == nullptr) return 0;
    494   return static_cast<jint>(TF_TensorType(t));
    495 }
    496 
    497 JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Tensor_shape(JNIEnv* env,
    498                                                               jclass clazz,
    499                                                               jlong handle) {
    500   TF_Tensor* t = requireHandle(env, handle);
    501   if (t == nullptr) return nullptr;
    502   static_assert(sizeof(jlong) == sizeof(int64_t),
    503                 "Java long is not compatible with the TensorFlow C API");
    504   const jsize num_dims = TF_NumDims(t);
    505   jlongArray ret = env->NewLongArray(num_dims);
    506   jlong* dims = env->GetLongArrayElements(ret, nullptr);
    507   for (int i = 0; i < num_dims; ++i) {
    508     dims[i] = static_cast<jlong>(TF_Dim(t, i));
    509   }
    510   env->ReleaseLongArrayElements(ret, dims, 0);
    511   return ret;
    512 }
    513 
    514 JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_setValue(JNIEnv* env,
    515                                                            jclass clazz,
    516                                                            jlong handle,
    517                                                            jobject value) {
    518   TF_Tensor* t = requireHandle(env, handle);
    519   if (t == nullptr) return;
    520   int num_dims = TF_NumDims(t);
    521   TF_DataType dtype = TF_TensorType(t);
    522   void* data = TF_TensorData(t);
    523   const size_t sz = TF_TensorByteSize(t);
    524   if (num_dims == 0) {
    525     writeScalar(env, value, dtype, data, sz);
    526   } else {
    527     writeNDArray(env, static_cast<jarray>(value), dtype, num_dims,
    528                  static_cast<char*>(data), sz);
    529   }
    530 }
    531 
    532 #define DEFINE_GET_SCALAR_METHOD(jtype, dtype, method_suffix)                  \
    533   JNIEXPORT jtype JNICALL Java_org_tensorflow_Tensor_scalar##method_suffix(    \
    534       JNIEnv* env, jclass clazz, jlong handle) {                               \
    535     jtype ret = 0;                                                             \
    536     TF_Tensor* t = requireHandle(env, handle);                                 \
    537     if (t == nullptr) return ret;                                              \
    538     if (TF_NumDims(t) != 0) {                                                  \
    539       throwException(env, kIllegalStateException, "Tensor is not a scalar");   \
    540     } else if (TF_TensorType(t) != dtype) {                                    \
    541       throwException(env, kIllegalStateException, "Tensor is not a %s scalar", \
    542                      #method_suffix);                                          \
    543     } else {                                                                   \
    544       memcpy(&ret, TF_TensorData(t), elemByteSize(dtype));                     \
    545     }                                                                          \
    546     return ret;                                                                \
    547   }
    548 DEFINE_GET_SCALAR_METHOD(jfloat, TF_FLOAT, Float);
    549 DEFINE_GET_SCALAR_METHOD(jdouble, TF_DOUBLE, Double);
    550 DEFINE_GET_SCALAR_METHOD(jint, TF_INT32, Int);
    551 DEFINE_GET_SCALAR_METHOD(jlong, TF_INT64, Long);
    552 DEFINE_GET_SCALAR_METHOD(jboolean, TF_BOOL, Boolean);
    553 #undef DEFINE_GET_SCALAR_METHOD
    554 
    555 JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Tensor_scalarBytes(
    556     JNIEnv* env, jclass clazz, jlong handle) {
    557   TF_Tensor* t = requireHandle(env, handle);
    558   if (t == nullptr) return nullptr;
    559   if (TF_NumDims(t) != 0) {
    560     throwException(env, kIllegalStateException, "Tensor is not a scalar");
    561     return nullptr;
    562   }
    563   if (TF_TensorType(t) != TF_STRING) {
    564     throwException(env, kIllegalArgumentException,
    565                    "Tensor is not a string/bytes scalar");
    566     return nullptr;
    567   }
    568   const char* data = static_cast<const char*>(TF_TensorData(t));
    569   const char* src = data + 8;
    570   size_t src_len = TF_TensorByteSize(t) - 8;
    571   uint64_t offset = 0;
    572   memcpy(&offset, data, sizeof(offset));
    573   if (offset >= src_len) {
    574     throwException(env, kIllegalArgumentException,
    575                    "invalid tensor encoding: bad offsets");
    576     return nullptr;
    577   }
    578   TF_Status* status = TF_NewStatus();
    579   jbyteArray ret = TF_StringDecodeTojbyteArray(env, src, src_len, status);
    580   throwExceptionIfNotOK(env, status);
    581   TF_DeleteStatus(status);
    582   return ret;
    583 }
    584 
    585 JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_readNDArray(JNIEnv* env,
    586                                                               jclass clazz,
    587                                                               jlong handle,
    588                                                               jobject value) {
    589   TF_Tensor* t = requireHandle(env, handle);
    590   if (t == nullptr) return;
    591   int num_dims = TF_NumDims(t);
    592   TF_DataType dtype = TF_TensorType(t);
    593   const void* data = TF_TensorData(t);
    594   const size_t sz = TF_TensorByteSize(t);
    595   if (num_dims == 0) {
    596     throwException(env, kIllegalArgumentException,
    597                    "copyTo() is not meant for scalar Tensors, use the scalar "
    598                    "accessor (floatValue(), intValue() etc.) instead");
    599     return;
    600   }
    601   if (dtype == TF_STRING) {
    602     int64_t num_elements = 1;
    603     for (int i = 0; i < num_dims; ++i) {
    604       num_elements *= TF_Dim(t, i);
    605     }
    606     StringTensorReader reader(t, num_elements);
    607     TF_Status* status = TF_NewStatus();
    608     readNDStringArray(env, &reader, num_dims, static_cast<jobjectArray>(value),
    609                       status);
    610     throwExceptionIfNotOK(env, status);
    611     TF_DeleteStatus(status);
    612     return;
    613   }
    614   readNDArray(env, dtype, static_cast<const char*>(data), sz, num_dims,
    615               static_cast<jarray>(value));
    616 }
    617