Home | History | Annotate | Download | only in eager
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <stdlib.h>
     17 
     18 #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
     19 #include "tensorflow/python/lib/core/numpy.h"
     20 #include "tensorflow/python/lib/core/py_seq_tensor.h"
     21 #include "tensorflow/python/lib/core/safe_ptr.h"
     22 
     23 #include "tensorflow/python/eager/pywrap_tensor.h"
     24 #include "tensorflow/python/eager/pywrap_tfe.h"
     25 
     26 #include "tensorflow/c/c_api.h"
     27 #include "tensorflow/core/lib/strings/strcat.h"
     28 #include "tensorflow/python/lib/core/ndarray_tensor.h"
     29 
     30 namespace {
     31 
     32 TFE_Context* GetContext(PyObject* ctx) {
     33   TFE_Context* context =
     34       reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(ctx, nullptr));
     35   if (context == nullptr) {
     36     PyErr_SetString(PyExc_TypeError,
     37                     tensorflow::strings::StrCat(
     38                         "Expecting a PyCapsule encoded context handle. Got ",
     39                         Py_TYPE(ctx)->tp_name)
     40                         .c_str());
     41   }
     42   return context;
     43 }
     44 
     45 // Convert a Python numpy.ndarray object to a TFE_TensorHandle.
     46 // The two may share underlying storage so changes to one may reflect in the
     47 // other.
     48 TFE_TensorHandle* NumpyToTensorHandle(PyObject* obj) {
     49   tensorflow::Tensor t;
     50   auto cppstatus = tensorflow::NdarrayToTensor(obj, &t);
     51   if (cppstatus.ok()) {
     52     return TFE_NewTensorHandle(t);
     53   } else {
     54     PyErr_SetString(PyExc_ValueError,
     55                     tensorflow::strings::StrCat(
     56                         "Failed to convert numpy ndarray to a Tensor (",
     57                         cppstatus.error_message(), ").")
     58                         .c_str());
     59     return nullptr;
     60   }
     61 }
     62 
// Casts data referred to by `handle` from type `src_type_enum` to type
// `dst_type_enum`.
// Builds and executes an eager "Cast" op. On failure, the error is reported
// through `out_status` (no Python error is set) and nullptr is returned.
// The caller owns the returned handle.
TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle,
                            TF_DataType src_type_enum,
                            TF_DataType dst_type_enum, TF_Status* out_status) {
  if (ctx == nullptr) return nullptr;
  const char* op_name = "Cast";
  // The cast is pinned to the local CPU device.
  const char* device_name = "/job:localhost/replica:0/task:0/device:CPU:0";
  TFE_Op* op = TFE_NewOp(ctx, op_name, out_status);
// Frees the op before bailing out; the error itself travels in `out_status`.
#define RETURN_ERROR  \
  {                   \
    TFE_DeleteOp(op); \
    return nullptr;   \
  }
  if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR
  TFE_OpSetDevice(op, device_name, out_status);
  if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR
  TFE_OpAddInput(op, handle, out_status);
  if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR
  TFE_OpSetAttrType(op, "SrcT", src_type_enum);
  TFE_OpSetAttrType(op, "DstT", dst_type_enum);
  TFE_TensorHandle* output = nullptr;
  int num_outputs = 1;
  TFE_Execute(op, &output, &num_outputs, out_status);
  if (TF_GetCode(out_status) != TF_OK || num_outputs != 1 ||
      output == nullptr) {
    // Execution may have produced an output even on failure; don't leak it.
    if (output != nullptr) {
      TFE_DeleteTensorHandle(output);
    }
    RETURN_ERROR
  }
  TFE_DeleteOp(op);
  return output;
#undef RETURN_ERROR
}
     98 
     99 TFE_TensorHandle* CopyToDevice(TFE_TensorHandle* handle, PyObject* ctx,
    100                                PyObject* dev) {
    101   const char* device = "";
    102   if (dev != nullptr && dev != Py_None) {
    103     device = PyBytes_AsString(dev);
    104 #if PY_MAJOR_VERSION >= 3
    105     if (device == nullptr) {
    106       PyErr_Clear();
    107       device = PyUnicode_AsUTF8(dev);
    108     }
    109 #endif
    110     if (device == nullptr) {
    111       PyErr_SetString(PyExc_TypeError,
    112                       "Error parsing device argument to CopyToDevice");
    113       return nullptr;
    114     }
    115   }
    116   TFE_Context* context = GetContext(ctx);
    117   if (context == nullptr) {  // PyErr already set by GetContext
    118     return nullptr;
    119   }
    120   auto status = tensorflow::make_safe(TF_NewStatus());
    121   TFE_TensorHandle* new_handle =
    122       TFE_TensorHandleCopyToDevice(handle, context, device, status.get());
    123   if (TF_GetCode(status.get()) != TF_OK) {
    124     PyErr_SetString(
    125         PyExc_RuntimeError,
    126         tensorflow::strings::StrCat("Error copying tensor to device: ", device,
    127                                     ". ", TF_Message(status.get()))
    128             .c_str());
    129     return nullptr;
    130   }
    131   return new_handle;
    132 }
    133 
    134 // Helper function to convert `v` to an int and store it in `*out`. Returns true
    135 // on success, false otherwise.
    136 // Note that we assume that v is a python int (not long) representing a
    137 // TF_DataType value.
    138 bool PyIntToDataType(PyObject* v, int* out) {
    139 #if PY_MAJOR_VERSION < 3
    140   if (PyInt_Check(v)) {
    141     *out = PyInt_AS_LONG(v);
    142     return true;
    143   }
    144 #else
    145   if (PyLong_Check(v)) {
    146     *out = PyLong_AsLong(v);
    147     return true;
    148   }
    149 #endif
    150   return false;
    151 }
    152 
    153 // Helper function to create a python integer from TF_DataType.
    154 PyObject* PyIntFromDataType(TF_DataType l) {
    155 #if PY_MAJOR_VERSION < 3
    156   return PyInt_FromLong(l);
    157 #else
    158   return PyLong_FromLong(l);
    159 #endif
    160 }
    161 
    162 }  // namespace
    163 
    164 extern "C" {
    165 
// Space reserved inside EagerTensor for the runtime-selected parent class;
// TFE_Py_InitEagerTensor rejects base classes larger than this.
static const int kMaxEagerTensorParentSize = 32;

// TODO(agarwal): store context handle in EagerTensor.
typedef struct EagerTensor {
  PyObject_HEAD;
  // Note that we leave kMaxEagerTensorParentSize bytes here for use by the
  // parent class. The parent class is set at runtime, so we don't know the
  // exact size at compile time.
  char unused[kMaxEagerTensorParentSize];
  // Owned; deleted in EagerTensor_dealloc.
  TFE_TensorHandle* handle;
  // Unique id (from get_uid()); used to identify this tensor to the tape.
  int64_t id;
  // This mirrors tensorflow.core.framework.ops.Tensor._handle_data Which will
  // be None for tensors of type other than DT_RESOURCE. For DT_RESOURCE
  // tensors, this will contain a serialized HandleData proto with shape
  // inference metadata about shapes and dtypes of resources accessible from
  // this handle.
  // Note that we assume that handle_data cannot participate in reference
  // cycles, and hence don't provide GC support for it.
  PyObject* handle_data;

  // This stores `_keras_mask` object and is set by Tensorflow layers.
  PyObject* keras_mask;
} EagerTensor;
    189 
// tp_init for EagerTensor.
// Implements EagerTensor(value, context, device, dtype=None): converts
// `value` (a numpy ndarray or an arbitrary Python sequence) into a
// TFE_TensorHandle, optionally casting to `dtype` and copying to `device`.
// Returns 0 on success, -1 with a Python error set on failure.
int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
  self->id = get_uid();
  // Initialize all fields up front so EagerTensor_dealloc is safe even if
  // initialization fails partway through.
  self->handle = nullptr;
  Py_INCREF(Py_None);
  self->handle_data = Py_None;
  Py_INCREF(Py_None);
  self->keras_mask = Py_None;
  PyObject* value;
  PyObject* context = nullptr;
  PyObject* device = nullptr;
  PyObject* dtype = Py_None;
  const char* kwlist[] = {"value", "context", "device", "dtype", nullptr};
  if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|O",
                                   const_cast<char**>(kwlist), &value, &context,
                                   &device, &dtype)) {
    return -1;
  }
  // Extract dtype
  int desired_dtype = -1;
  if (dtype != Py_None) {
    if (!PyIntToDataType(dtype, &desired_dtype)) {
      PyErr_SetString(PyExc_TypeError,
                      tensorflow::strings::StrCat(
                          "Expecting a DataType value for dtype. Got ",
                          Py_TYPE(dtype)->tp_name)
                          .c_str());
      return -1;
    }
  }
  tensorflow::Safe_TFE_TensorHandlePtr handle =
      tensorflow::make_safe(static_cast<TFE_TensorHandle*>(nullptr));
  PyErr_Clear();
  if (PyArray_Check(value)) {
    // Numpy path: possibly recast/copy the array so it is C-contiguous and of
    // the requested dtype, then wrap it (sharing storage where possible).
    int desired_np_dtype = -1;
    if (desired_dtype >= 0) {
      if (!tensorflow::TF_DataType_to_PyArray_TYPE(
               static_cast<TF_DataType>(desired_dtype), &desired_np_dtype)
               .ok()) {
        PyErr_SetString(PyExc_TypeError,
                        tensorflow::strings::StrCat(
                            "Invalid dtype argument value ", desired_dtype)
                            .c_str());
        return -1;
      }
    }
    PyArrayObject* array = reinterpret_cast<PyArrayObject*>(value);
    int current_np_dtype = PyArray_TYPE(array);
    auto safe_value = tensorflow::make_safe(static_cast<PyObject*>(nullptr));
    if ((desired_np_dtype >= 0 && desired_np_dtype != current_np_dtype) ||
        !PyArray_ISCARRAY(array)) {
      int new_dtype =
          desired_np_dtype >= 0 ? desired_np_dtype : current_np_dtype;
      // PyArray_FromAny returns a new reference; keep it alive in safe_value
      // for the rest of this function.
      safe_value = tensorflow::make_safe(
          PyArray_FromAny(value, PyArray_DescrFromType(new_dtype), 0, 0,
                          NPY_ARRAY_CARRAY | NPY_ARRAY_FORCECAST, nullptr));
      if (PyErr_Occurred()) return -1;
      if (safe_value == nullptr) {
        PyErr_SetString(PyExc_ValueError, "Error while casting a numpy value");
        return -1;
      }
      value = safe_value.get();
    }
    handle = tensorflow::make_safe(NumpyToTensorHandle(value));
  } else {
    // Generic path: convert an arbitrary Python sequence/scalar.
    tensorflow::Tensor t;
    // TODO(josh11b): Have PySeqToTensor set python errors instead of
    // returning Status.
    auto cppstatus = tensorflow::PySeqToTensor(value, dtype, &t);
    if (!cppstatus.ok()) {
      PyErr_SetString(PyExc_ValueError, cppstatus.error_message().c_str());
      return -1;
    }
    handle = tensorflow::make_safe(TFE_NewTensorHandle(t));
  }
  if (PyErr_Occurred()) return -1;
  if (handle == nullptr) {
    PyErr_SetString(PyExc_ValueError, "Error while creating an EagerTensor");
    return -1;
  }
  TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get());
  if (desired_dtype >= 0 && desired_dtype != handle_dtype) {
    // The conversion above may not have produced the requested dtype (e.g.
    // for the sequence path); fall back to an explicit eager Cast op.
    auto out_status = tensorflow::make_safe(TF_NewStatus());
    handle = tensorflow::make_safe(
        EagerCast(GetContext(context), handle.get(), handle_dtype,
                  static_cast<TF_DataType>(desired_dtype), out_status.get()));
    if (TF_GetCode(out_status.get()) != TF_OK) {
      PyErr_SetString(
          PyExc_ValueError,
          tensorflow::strings::StrCat("Error while casting from DataType ",
                                      handle_dtype, " to ", desired_dtype, ". ",
                                      TF_Message(out_status.get()))
              .c_str());
      return -1;
    }
    handle_dtype = TFE_TensorHandleDataType(handle.get());
  }

  // Almost all TensorFlow kernels for GPU devices keep int32 tensors in host
  // memory. We approximate the same behavior for eager execution - keeping
  // int32 tensors in host memory.
  //
  // We do so to preclude the need for callers into such kernels from having to
  // explicitly place the int32 tensors in host memory. For example, without
  // this, one needed:
  //
  // with tf.device('/gpu:0'):
  //   ...// code here
  //   with tf.device('/cpu:0'):
  //     shape = tf.constant(...)
  //   y = tf.random_uniform(shape)
  //
  // Without the CPU device block, tfe.ops.random_uniform would fail since the
  // kernel expects the shape in host memory.
  //
  // With this support, we simplify the code:
  //
  // with tf.device('/gpu:0'):
  //   y = tf.random_uniform(...)
  //
  // The approximation is not exact there are GPU kernels which do not require
  // host memory for int32 tensors. This will lead to a discrepancy between
  // eager and graph execution.
  // TODO(ashankar): Fix this.
  if (handle_dtype != TF_INT32) {
    // Note that this is a shallow copy and will share the underlying buffer
    // if copying to the same device.
    handle = tensorflow::make_safe(CopyToDevice(handle.get(), context, device));
    if (handle == nullptr) return -1;
  }
  // Success: transfer ownership of the handle to the Python object.
  self->handle = handle.release();
  return 0;
}
    323 
    324 // tp_dealloc for EagerTensor.
    325 void EagerTensor_dealloc(EagerTensor* self) {
    326   Py_DECREF(self->handle_data);
    327   Py_DECREF(self->keras_mask);
    328   TFE_DeleteTensorHandle(self->handle);
    329   self->handle = nullptr;
    330   // We have the global interpreter lock, so use this chance to perform delayed
    331   // refcount decrements.
    332   tensorflow::ClearDecrefCache();
    333   auto id = self->id;
    334   Py_TYPE(self)->tp_free(self);
    335   TFE_Py_TapeSetDeleteTrace(id);
    336 }
    337 
    338 // Getter for `_id`.
    339 static PyObject* EagerTensor_getid(EagerTensor* self, void* closure) {
    340   return PyLong_FromLongLong(self->id);
    341 }
    342 
    343 // Getter for `_datatype_enum`.
    344 static PyObject* EagerTensor_datatype_enum(EagerTensor* self) {
    345   return PyIntFromDataType(TFE_TensorHandleDataType(self->handle));
    346 }
    347 
    348 // Getter for `_shape_tuple`.
    349 static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
    350   auto handle = self->handle;
    351   int n = TFE_TensorHandleNumDims(handle);
    352   PyObject* shape = PyTuple_New(n);
    353   if (PyErr_Occurred()) return nullptr;
    354   for (int i = 0; i < n; ++i) {
    355     PyObject* dim = PyLong_FromLongLong(TFE_TensorHandleDim(handle, i));
    356     if (dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) {
    357       Py_DECREF(shape);
    358       if (dim != nullptr) Py_DECREF(dim);
    359       PyErr_SetString(PyExc_RuntimeError, "Error while creating shape");
    360       return nullptr;
    361     }
    362   }
    363   return shape;
    364 }
    365 
    366 // Getter for `_rank`.
    367 static PyObject* EagerTensor_rank(EagerTensor* self) {
    368 #if PY_MAJOR_VERSION < 3
    369   return PyInt_FromLong(TFE_TensorHandleNumDims(self->handle));
    370 #else
    371   return PyLong_FromLong(TFE_TensorHandleNumDims(self->handle));
    372 #endif
    373 }
    374 
    375 static PyObject* EagerTensor_tensor_handle(EagerTensor* self, void* unused) {
    376   Py_INCREF(self->handle_data);
    377   return self->handle_data;
    378 }
    379 
    380 static int EagerTensor_settensor_handle(EagerTensor* self, PyObject* value,
    381                                         void* unused) {
    382   Py_DECREF(self->handle_data);
    383   Py_INCREF(value);
    384   self->handle_data = value;
    385   return 0;
    386 }
    387 
    388 static PyObject* EagerTensor_keras_mask(EagerTensor* self, void* unused) {
    389   Py_INCREF(self->keras_mask);
    390   return self->keras_mask;
    391 }
    392 
    393 static int EagerTensor_setkeras_mask(EagerTensor* self, PyObject* value,
    394                                      void* unused) {
    395   Py_DECREF(self->keras_mask);
    396   Py_INCREF(value);
    397   self->keras_mask = value;
    398   return 0;
    399 }
    400 // Function `_copy_to_device`.
    401 static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args,
    402                                             PyObject* kwds) {
    403   const char* kwlist[] = {"context", "device", nullptr};
    404   PyObject* ctx = nullptr;
    405   PyObject* dev = nullptr;
    406   if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO", const_cast<char**>(kwlist),
    407                                    &ctx, &dev) ||
    408       !ctx || !dev) {
    409     return nullptr;
    410   }
    411   auto handle = CopyToDevice(self->handle, ctx, dev);
    412   return EagerTensorFromHandle(handle);
    413 }
    414 
    415 // Function `_numpy`.
    416 // Convert an EagerTensor to a Python numpy.ndarray object.
    417 // The two may share underlying storage so changes to one may reflect in the
    418 // other.
    419 // Note that if `self` is not on CPU, we raise an Exception.
    420 static PyObject* EagerTensor_numpy(EagerTensor* self) {
    421   auto status = tensorflow::make_safe(TF_NewStatus());
    422   const tensorflow::Tensor* t =
    423       TFE_TensorHandleUnderlyingTensorInHostMemory(self->handle, status.get());
    424   if (TF_GetCode(status.get()) != TF_OK) {
    425     PyErr_SetString(PyExc_RuntimeError, TF_Message(status.get()));
    426     return nullptr;
    427   }
    428   PyObject* ret = nullptr;
    429   auto cppstatus = tensorflow::TensorToNdarray(*t, &ret);
    430   if (MaybeRaiseExceptionFromStatus(cppstatus, PyExc_RuntimeError)) {
    431     Py_XDECREF(ret);
    432     return nullptr;
    433   } else {
    434     return ret;
    435   }
    436 }
    437 
    438 // Getter `device`.
    439 static PyObject* EagerTensor_device(EagerTensor* self) {
    440 #if PY_MAJOR_VERSION >= 3
    441   return PyUnicode_FromString(TFE_TensorHandleDeviceName(self->handle));
    442 #else
    443   return PyBytes_FromString(TFE_TensorHandleDeviceName(self->handle));
    444 #endif
    445 }
    446 
// Property table: exposes `_id`, `device`, `_handle_data` and `_keras_mask`
// on EagerTensor instances. `_id` and `device` are read-only.
static PyGetSetDef EagerTensor_getseters[] = {
    {const_cast<char*>("_id"), (getter)EagerTensor_getid, nullptr,
     const_cast<char*>("_id"), nullptr},
    {const_cast<char*>("device"), (getter)EagerTensor_device, nullptr,
     const_cast<char*>("device"), nullptr},
    // NOTE(review): the doc string "_tensor_handle" does not match the
    // attribute name "_handle_data" -- presumably a leftover; confirm.
    {const_cast<char*>("_handle_data"), (getter)EagerTensor_tensor_handle,
     (setter)EagerTensor_settensor_handle, const_cast<char*>("_tensor_handle"),
     nullptr},
    {const_cast<char*>("_keras_mask"), (getter)EagerTensor_keras_mask,
     (setter)EagerTensor_setkeras_mask, const_cast<char*>("_keras_mask"),
     nullptr},
    {nullptr} /* Sentinel */
};
    460 
// Method table for EagerTensor. All methods are underscore-prefixed and
// wrapped/renamed at the Python level.
static PyMethodDef EagerTensor_methods[] = {
    {"_numpy", (PyCFunction)EagerTensor_numpy, METH_NOARGS,
     PyDoc_STR("_numpy")},
    {"_datatype_enum", (PyCFunction)EagerTensor_datatype_enum, METH_NOARGS,
     PyDoc_STR("_datatype_enum")},
    {"_shape_tuple", (PyCFunction)EagerTensor_shape_tuple, METH_NOARGS,
     PyDoc_STR("_shape_tuple")},
    {"_rank", (PyCFunction)EagerTensor_rank, METH_NOARGS, PyDoc_STR("_rank")},
    {"_copy_to_device", (PyCFunction)EagerTensor_copy_to_device,
     METH_VARARGS | METH_KEYWORDS, PyDoc_STR("_copy_to_device")},
    {nullptr, nullptr},
};
    473 
    474 // Note that here we are trying to dynamically create a new class as a subclass
    475 // of a "HEAPTYPE" class that is itself created in python code and passed in at
    476 // runtime. This is fairly atypical and undocumented.
    477 //
    478 // We use the following strategy for this. Unfortunately, we have to use
    479 // different approaches for python2.x vs python3.x
    480 // For python2.x, we create the class as a static type and set its tp_base to
    481 // the passed in type. Unfortunately setting tp_flags to include
    482 // Py_TPFLAGS_HEAPTYPE does not work by itself since it needs some more
    483 // initialization of the underlying PyHeapTypeObject and not doing that leads to
    484 // some random crashes especially during garbage collection.
    485 // python3.x explicitly disables a static subclass of a HEAPTYPE base class.
    486 // However it provides a new function, PyType_FromSpecWithBases, to create
    487 // types dynamically.
    488 
// Type object for EagerTensor. This is set by TFE_Py_InitEagerTensor.
// Remains nullptr until that initializer has been called from Python.
PyTypeObject* EagerTensorType = nullptr;
    491 
#if PY_MAJOR_VERSION >= 3
// Slot table for the dynamically created EagerTensor type (Python 3 path).
// Entries are (slot id, value) pairs, terminated by a (0, nullptr) pair.
static PyType_Slot EagerTensor_Type_slots[] = {
    Py_tp_dealloc,
    reinterpret_cast<void*>(EagerTensor_dealloc),
    Py_tp_methods,
    reinterpret_cast<void*>(EagerTensor_methods),
    Py_tp_getset,
    reinterpret_cast<void*>(EagerTensor_getseters),
    Py_tp_init,
    reinterpret_cast<void*>(EagerTensor_init),
    0,
    nullptr,
};

// Spec consumed by PyType_FromSpecWithBases in TFE_Py_InitEagerTensor.
PyType_Spec EagerTensor_Type_spec = {"EagerTensor", sizeof(EagerTensor), 0,
                                     Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE,
                                     EagerTensor_Type_slots};
#else
// Python 2 path: a static type whose tp_base is patched in at runtime by
// TFE_Py_InitEagerTensor (see the strategy comment above).
// TODO(agarwal): support active_trace.
static PyTypeObject _EagerTensorType = {
    // clang-format off
    PyVarObject_HEAD_INIT(nullptr, 0)
    // clang-format on
    "EagerTensor",                   /* tp_name */
    sizeof(EagerTensor),             /* tp_basicsize */
    0,                               /* tp_itemsize */
    (destructor)EagerTensor_dealloc, /* tp_dealloc */
    nullptr,                         /* tp_print */
    nullptr,                         /* tp_getattr */
    nullptr,                         /* tp_setattr */
    nullptr,                         /* tp_compare */
    nullptr,                         /* tp_repr */
    nullptr,                         /* tp_as_number */
    nullptr,                         /* tp_as_sequence */
    nullptr,                         /* tp_as_mapping */
    nullptr,                         /* tp_hash */
    nullptr,                         /* tp_call */
    nullptr,                         /* tp_str */
    nullptr,                         /* tp_getattro */
    nullptr,                         /* tp_setattro */
    nullptr,                         /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT,              /* tp_flags */
    nullptr,                         /* tp_doc */
    nullptr,                         /* tp_traverse */
    nullptr,                         /* tp_clear */
    nullptr,                         /* tp_richcompare */
    0,                               /* tp_weaklistoffset */
    nullptr,                         /* tp_iter */
    nullptr,                         /* tp_iternext */
    EagerTensor_methods,             /* tp_methods */
    nullptr,                         /* tp_members */
    EagerTensor_getseters,           /* tp_getset */
    nullptr,                         /* tp_base */
    nullptr,                         /* tp_dict */
    nullptr,                         /* tp_descr_get */
    nullptr,                         /* tp_descr_set */
    0,                               /* tp_dictoffset */
    (initproc)EagerTensor_init,      /* tp_init */
    nullptr,                         /* tp_alloc */
    nullptr,                         /* tp_new */
};

#endif
    555 
    556 }  // extern "C"
    557 
    558 bool EagerTensor_CheckExact(const PyObject* o) {
    559   return Py_TYPE(o) == EagerTensorType;
    560 }
    561 
    562 TFE_TensorHandle* EagerTensor_Handle(const PyObject* o) {
    563   return reinterpret_cast<const EagerTensor*>(o)->handle;
    564 }
    565 
    566 PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) {
    567   if (handle == nullptr) {
    568     return nullptr;
    569   }
    570   EagerTensor* t = reinterpret_cast<EagerTensor*>(
    571       EagerTensorType->tp_new(EagerTensorType, Py_None, Py_None));
    572   if (t != nullptr) {
    573     t->id = get_uid();
    574     Py_INCREF(Py_None);
    575     t->handle_data = Py_None;
    576     Py_INCREF(Py_None);
    577     t->keras_mask = Py_None;
    578     t->handle = handle;
    579   }
    580   return reinterpret_cast<PyObject*>(t);
    581 }
    582 
    583 tensorflow::int64 EagerTensor_id(const PyObject* tensor) {
    584   CHECK(EagerTensor_CheckExact(tensor));
    585   return reinterpret_cast<const EagerTensor*>(tensor)->id;
    586 }
    587 
// Creates the EagerTensor type as a subclass of `base_class` (a heap type
// defined in Python) and stores it in the global EagerTensorType.
// Returns the new type object, or nullptr with a Python error set on failure.
PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
  if (!PyType_Check(base_class)) {
    PyErr_SetString(
        PyExc_TypeError,
        tensorflow::strings::StrCat(
            "Expecting a class definition for `base_class` passed to ",
            "TFE_InitEagerTensor. Got ", Py_TYPE(base_class)->tp_name)
            .c_str());
    return nullptr;
  }
  // Note that we allocated kMaxEagerTensorParentSize bytes of unused space in
  // EagerTensor to allow for the space usage of the base class.
  PyTypeObject* base_class_type = reinterpret_cast<PyTypeObject*>(base_class);
  if (base_class_type->tp_basicsize > kMaxEagerTensorParentSize) {
    PyErr_SetString(
        PyExc_TypeError,
        tensorflow::strings::StrCat(
            "Unable to create subclass EagerTensor from base class ",
            Py_TYPE(base_class)->tp_name,
            ". Need its size to be <= ", kMaxEagerTensorParentSize)
            .c_str());
    return nullptr;
  }
  // Variable-length bases would make the fixed-layout EagerTensor invalid.
  if (base_class_type->tp_itemsize != 0) {
    PyErr_SetString(
        PyExc_TypeError,
        tensorflow::strings::StrCat(
            "Unable to create subclass EagerTensor from base class ",
            Py_TYPE(base_class)->tp_name,
            " which supports variable length instances.")
            .c_str());
    return nullptr;
  }
  // Keep the base class alive for the lifetime of the process.
  Py_INCREF(base_class);
#if PY_MAJOR_VERSION >= 3
  // Python 3: create the type dynamically from the spec + bases tuple.
  // NOTE(review): on the failure paths below, `bases` (and the base_class
  // incref above) are not released -- a one-time leak at init; confirm.
  PyObject* bases = PyTuple_New(1);
  PyTuple_SET_ITEM(bases, 0, base_class);
  EagerTensorType = reinterpret_cast<PyTypeObject*>(
      PyType_FromSpecWithBases(&EagerTensor_Type_spec, bases));
  if (PyErr_Occurred()) {
    return nullptr;
  }
  if (EagerTensorType == nullptr) {
    PyErr_SetString(PyExc_RuntimeError, "Error while creating EagerTensorType");
    return nullptr;
  }
#else
  // Python 2: patch the base into the static type and finalize it.
  _EagerTensorType.tp_base = reinterpret_cast<PyTypeObject*>(base_class);

  if (PyType_Ready(&_EagerTensorType) < 0) {
    if (PyErr_Occurred()) return nullptr;
    PyErr_SetString(PyExc_RuntimeError,
                    "Error while creating EagerTensor type.");
    return nullptr;
  }
  EagerTensorType = &_EagerTensorType;
  Py_INCREF(EagerTensorType);
#endif
  // We disable instance based attribute lookup. Its not clear if these
  // dictionaries are correctly initialized in the first place.
  EagerTensorType->tp_dictoffset = 0;
  return reinterpret_cast<PyObject*>(EagerTensorType);
}
    651 
// Builds a 1-D int32 EagerTensor whose i-th element is the size of dimension
// `slice_dim` of the i-th EagerTensor in `tensor_list`.
// Returns nullptr with a Python error set on invalid arguments or on failure.
PyObject* TFE_Py_TensorShapeSlice(PyObject* tensor_list, int slice_dim) {
  if (!PyList_Check(tensor_list)) {
    PyErr_SetString(PyExc_TypeError,
                    tensorflow::strings::StrCat(
                        "tensor_list argument must be a list. Got \"",
                        Py_TYPE(tensor_list)->tp_name, "\"")
                        .c_str());
    return nullptr;
  }
  if (slice_dim < 0) {
    PyErr_SetString(
        PyExc_ValueError,
        tensorflow::strings::StrCat("Slice dimension must be non-negative. "
                                    "Got ",
                                    slice_dim)
            .c_str());
    return nullptr;
  }

  Py_ssize_t num_tensors = PyList_Size(tensor_list);
  int64_t num_tensors_int = static_cast<int64_t>(num_tensors);
  // One int32 (4 bytes) per input tensor.
  auto tensor = tensorflow::make_safe(TF_AllocateTensor(
      TF_INT32, &num_tensors_int, /*num_dims=*/1, /*len=*/4 * num_tensors_int));
  int32_t* data = reinterpret_cast<int32_t*>(TF_TensorData(tensor.get()));
  for (Py_ssize_t i = 0; i < num_tensors; ++i) {
    PyObject* tensor_obj = PyList_GET_ITEM(tensor_list, i);
    if (!EagerTensor_CheckExact(tensor_obj)) {
      PyErr_SetString(PyExc_TypeError,
                      tensorflow::strings::StrCat(
                          "Expected a list of EagerTensors but "
                          "element ",
                          i, " has type \"", Py_TYPE(tensor_obj)->tp_name, "\"")
                          .c_str());
      return nullptr;
    }

    EagerTensor* t = reinterpret_cast<EagerTensor*>(tensor_obj);
    TFE_TensorHandle* handle = t->handle;
    if (slice_dim >= TFE_TensorHandleNumDims(handle)) {
      PyErr_SetString(PyExc_IndexError,
                      tensorflow::strings::StrCat(
                          "Slice dimension (", slice_dim,
                          ") must be smaller than rank of all "
                          "tensors, but tensor at index ",
                          i, " has rank ", TFE_TensorHandleNumDims(handle))
                          .c_str());
      return nullptr;
    }
    int64_t dim = TFE_TensorHandleDim(handle, slice_dim);
    // NOTE(review): this narrows int64 -> int32 without an overflow check;
    // presumably dimension sizes here always fit in int32 -- confirm.
    data[i] = dim;
  }

  auto status = tensorflow::make_safe(TF_NewStatus());
  TFE_TensorHandle* handle = TFE_NewTensorHandle(tensor.get(), status.get());
  if (TF_GetCode(status.get()) != TF_OK) {
    PyErr_SetString(
        PyExc_RuntimeError,
        tensorflow::strings::StrCat("Failed to construct new tensor handle: ",
                                    TF_Message(status.get()))
            .c_str());
    return nullptr;
  }

  return EagerTensorFromHandle(handle);
}
    717