Home | History | Annotate | Download | only in eager
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <stdlib.h>
     17 
     18 #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
     19 #include "tensorflow/python/lib/core/numpy.h"
     20 #include "tensorflow/python/lib/core/py_seq_tensor.h"
     21 #include "tensorflow/python/lib/core/safe_ptr.h"
     22 
     23 #include "tensorflow/python/eager/pywrap_tensor.h"
     24 #include "tensorflow/python/eager/pywrap_tfe.h"
     25 
     26 #include "tensorflow/c/c_api.h"
     27 #include "tensorflow/core/lib/strings/strcat.h"
     28 #include "tensorflow/python/lib/core/ndarray_tensor.h"
     29 
     30 #include "tensorflow/core/framework/types.h"
     31 
     32 #include "structmember.h"  // NOLINT // For PyMemberDef
     33 
     34 // forward declare
     35 struct EagerTensor;
     36 
     37 namespace {
     38 
// An instance of _EagerTensorProfiler that will receive callbacks about
// events on eager tensors. This is set by TFE_Py_InitEagerTensor, if at all;
// while it stays null, profiler callbacks are skipped entirely.
PyObject* eager_tensor_profiler = nullptr;
     42 
     43 TFE_Context* GetContext(PyObject* ctx) {
     44   TFE_Context* context =
     45       reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(ctx, nullptr));
     46   if (context == nullptr) {
     47     PyErr_SetString(PyExc_TypeError,
     48                     tensorflow::strings::StrCat(
     49                         "Expecting a PyCapsule encoded context handle. Got ",
     50                         Py_TYPE(ctx)->tp_name)
     51                         .c_str());
     52   }
     53   return context;
     54 }
     55 
     56 // Convert a Python numpy.ndarray object to a TFE_TensorHandle.
     57 // The two may share underlying storage so changes to one may reflect in the
     58 // other.
     59 TFE_TensorHandle* NumpyToTensorHandle(PyObject* obj) {
     60   tensorflow::Tensor t;
     61   auto cppstatus = tensorflow::NdarrayToTensor(obj, &t);
     62   if (cppstatus.ok()) {
     63     return TFE_NewTensorHandle(t);
     64   } else {
     65     PyErr_SetString(PyExc_ValueError,
     66                     tensorflow::strings::StrCat(
     67                         "Failed to convert numpy ndarray to a Tensor (",
     68                         cppstatus.error_message(), ").")
     69                         .c_str());
     70     return nullptr;
     71   }
     72 }
     73 
     74 TFE_TensorHandle* CopyToDevice(TFE_TensorHandle* handle, PyObject* ctx,
     75                                PyObject* dev) {
     76   const char* device = "";
     77   if (dev != nullptr && dev != Py_None) {
     78     device = PyBytes_AsString(dev);
     79 #if PY_MAJOR_VERSION >= 3
     80     if (device == nullptr) {
     81       PyErr_Clear();
     82       device = PyUnicode_AsUTF8(dev);
     83     }
     84 #endif
     85     if (device == nullptr) {
     86       PyErr_SetString(PyExc_TypeError,
     87                       "Error parsing device argument to CopyToDevice");
     88       return nullptr;
     89     }
     90   }
     91   TFE_Context* context = GetContext(ctx);
     92   if (context == nullptr) {  // PyErr already set by GetContext
     93     return nullptr;
     94   }
     95   auto status = tensorflow::make_safe(TF_NewStatus());
     96   TFE_TensorHandle* new_handle =
     97       TFE_TensorHandleCopyToDevice(handle, context, device, status.get());
     98   if (TF_GetCode(status.get()) != TF_OK) {
     99     PyErr_SetString(
    100         PyExc_RuntimeError,
    101         tensorflow::strings::StrCat("Error copying tensor to device: ", device,
    102                                     ". ", TF_Message(status.get()))
    103             .c_str());
    104     return nullptr;
    105   }
    106   return new_handle;
    107 }
    108 
// Helper function to convert `v` to an int and store it in `*out`. Returns true
// on success, false otherwise.
// Note that we assume that v is a python int (not long) representing a
// TF_DataType value; values outside int range are silently narrowed by the
// long -> int conversion below.
bool PyIntToDataType(PyObject* v, int* out) {
#if PY_MAJOR_VERSION < 3
  // Python 2: TF_DataType values arrive as plain ints.
  if (PyInt_Check(v)) {
    *out = PyInt_AS_LONG(v);
    return true;
  }
#else
  // Python 3: every integer is a PyLong.
  if (PyLong_Check(v)) {
    *out = PyLong_AsLong(v);
    return true;
  }
#endif
  return false;
}
    127 
// Helper function to create a python integer from TF_DataType.
// Returns a new reference.
PyObject* PyIntFromDataType(TF_DataType l) {
#if PY_MAJOR_VERSION < 3
  return PyInt_FromLong(l);
#else
  return PyLong_FromLong(l);
#endif
}
    136 
    137 }  // namespace
    138 
    139 namespace tensorflow {
    140 // This function checks whether the desired type is "compatible" with the
    141 // inferred type. At a high level, compatibility means that all integral types
    142 // are compatible with each other, and all floating types are compatible with
    143 // each other.
    144 //
    145 // Type compatibility doesn't consider overflows (i.e. int64 is *always*
    146 // compatible with int32). This is intended to match graph behavior.
    147 bool IsCompatible(int desired_dtype, TF_DataType returned_dtype) {
    148   tensorflow::DataType desired =
    149       static_cast<tensorflow::DataType>(desired_dtype);
    150   tensorflow::DataType returned =
    151       static_cast<tensorflow::DataType>(returned_dtype);
    152 
    153   if (desired == returned) return true;
    154 
    155   if (tensorflow::DataTypeIsInteger(desired) &&
    156       tensorflow::DataTypeIsInteger(returned)) {
    157     return true;
    158   } else if (tensorflow::DataTypeIsFloating(desired) &&
    159              (tensorflow::DataTypeIsFloating(returned) ||
    160               tensorflow::DataTypeIsInteger(returned))) {
    161     return true;
    162   } else if (tensorflow::DataTypeIsComplex(desired) &&
    163              (tensorflow::DataTypeIsComplex(returned) ||
    164               tensorflow::DataTypeIsInteger(returned) ||
    165               tensorflow::DataTypeIsFloating(returned))) {
    166     return true;
    167   } else if (tensorflow::DataTypeIsQuantized(desired) &&
    168              tensorflow::DataTypeIsInteger(returned)) {
    169     return true;
    170   }
    171   return false;
    172 }
    173 
// Casts data referred to by `handle` from type `src_type_enum` to type
// `dst_type_enum` by executing a "Cast" op. Returns nullptr on failure (the
// error, if any, is recorded in `out_status`); on success the caller owns the
// returned handle. The input `handle` is not consumed.
TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle,
                            TF_DataType src_type_enum,
                            TF_DataType dst_type_enum, TF_Status* out_status) {
  if (ctx == nullptr) return nullptr;
  const char* op_name = "Cast";
  // The cast is always pinned to the local CPU device.
  const char* device_name = "/job:localhost/replica:0/task:0/device:CPU:0";
  TFE_Op* op = TFE_NewOp(ctx, op_name, out_status);
// Deletes the (possibly partially configured) op and bails out.
// NOTE(review): if TFE_NewOp failed, `op` may be null here — this assumes
// TFE_DeleteOp tolerates a null argument; confirm against the eager C API.
#define RETURN_ERROR  \
  {                   \
    TFE_DeleteOp(op); \
    return nullptr;   \
  }
  if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR
  TFE_OpSetDevice(op, device_name, out_status);
  if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR
  TFE_OpAddInput(op, handle, out_status);
  if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR
  // Cast attributes: source type, destination type, no integer truncation.
  TFE_OpSetAttrType(op, "SrcT", src_type_enum);
  TFE_OpSetAttrType(op, "DstT", dst_type_enum);
  TFE_OpSetAttrBool(op, "Truncate", false);
  TFE_TensorHandle* output = nullptr;
  int num_outputs = 1;
  TFE_Execute(op, &output, &num_outputs, out_status);
  if (TF_GetCode(out_status) != TF_OK || num_outputs != 1 ||
      output == nullptr) {
    // Don't leak an output handle that was produced despite the error.
    if (output != nullptr) {
      TFE_DeleteTensorHandle(output);
    }
    RETURN_ERROR
  }
  TFE_DeleteOp(op);
  return output;
#undef RETURN_ERROR
}
    210 
// Converts a Python object (numpy scalar, numpy array, or Python sequence)
// into a TFE_TensorHandle. `dtype` is either Py_None or a Python int carrying
// a TF_DataType value. Returns nullptr with a Python exception set on failure.
TFE_TensorHandle* ConvertToEagerTensor(PyObject* value, PyObject* dtype) {
  int desired_dtype = -1;
  if (dtype != Py_None) {
    if (!PyIntToDataType(dtype, &desired_dtype)) {
      PyErr_SetString(PyExc_TypeError,
                      tensorflow::strings::StrCat(
                          "Expecting a DataType value for dtype. Got ",
                          Py_TYPE(dtype)->tp_name)
                          .c_str());
      return nullptr;
    }
  }
  tensorflow::Safe_PyObjectPtr value_decrefer;
  if (PyArray_IsScalar(value, Generic)) {
    // Convert numpy scalars to numpy arrays.
    value = PyArray_FromScalar(value, nullptr);
    // The returned value needs to be DECREF'd, but the original value was
    // created in python code, and doesn't need to be DECREF'd.
    value_decrefer.reset(value);
  }
  if (PyArray_Check(value)) {
    int desired_np_dtype = -1;
    if (desired_dtype >= 0) {
      // Map the requested TF dtype onto the equivalent numpy type number.
      if (!tensorflow::TF_DataType_to_PyArray_TYPE(
               static_cast<TF_DataType>(desired_dtype), &desired_np_dtype)
               .ok()) {
        PyErr_SetString(PyExc_TypeError,
                        tensorflow::strings::StrCat(
                            "Invalid dtype argument value ", desired_dtype)
                            .c_str());
        return nullptr;
      }
    }
    PyArrayObject* array = reinterpret_cast<PyArrayObject*>(value);
    int current_np_dtype = PyArray_TYPE(array);
    auto safe_value = tensorflow::make_safe(static_cast<PyObject*>(nullptr));
    // Re-materialize the array when the dtype differs from the request or the
    // layout is not C-contiguous, since NdarrayToTensor needs a dense buffer.
    if ((desired_np_dtype >= 0 && desired_np_dtype != current_np_dtype) ||
        !PyArray_ISCARRAY(array)) {
      int new_dtype =
          desired_np_dtype >= 0 ? desired_np_dtype : current_np_dtype;
      safe_value = tensorflow::make_safe(
          PyArray_FromAny(value, PyArray_DescrFromType(new_dtype), 0, 0,
                          NPY_ARRAY_CARRAY | NPY_ARRAY_FORCECAST, nullptr));
      if (PyErr_Occurred()) return nullptr;
      if (safe_value == nullptr) {
        PyErr_SetString(PyExc_ValueError, "Error while casting a numpy value");
        return nullptr;
      }
      value = safe_value.get();
    }
    return NumpyToTensorHandle(value);
  } else {
    tensorflow::Tensor t;
    // TODO(josh11b): Have PySeqToTensor set python errors instead of
    // returning Status.
    auto cppstatus = tensorflow::PySeqToTensor(value, dtype, &t);
    if (!cppstatus.ok()) {
      PyErr_SetString(PyExc_ValueError, cppstatus.error_message().c_str());
      return nullptr;
    }
    return TFE_NewTensorHandle(t);
  }
}
    274 }  // namespace tensorflow
    275 
    276 extern "C" {
    277 
// Number of bytes reserved in EagerTensor for the runtime-selected parent
// class's instance data (see the comment inside the struct).
static const int kMaxEagerTensorParentSize = 64;

// TODO(agarwal): store context handle in EagerTensor.
typedef struct EagerTensor {
  PyObject_HEAD;
  // Note that we leave kMaxEagerTensorParentSize bytes here for use by the
  // parent class. The parent class is set at runtime, so we don't know the
  // exact size at compile time.
  char unused[kMaxEagerTensorParentSize];
  // Owning pointer to the underlying eager tensor handle; freed in
  // EagerTensor_dealloc.
  TFE_TensorHandle* handle;
  // Unique id assigned from get_uid() in EagerTensor_init.
  int64_t id;
  // This mirrors tensorflow.core.framework.ops.Tensor._handle_data Which will
  // be None for tensors of type other than DT_RESOURCE. For DT_RESOURCE
  // tensors, this will contain a serialized HandleData proto with shape
  // inference metadata about shapes and dtypes of resources accessible from
  // this handle.
  // Note that we assume that handle_data cannot participate in reference
  // cycles, and hence don't provide GC support for it.
  PyObject* handle_data;

  // This stores `_keras_mask` object and is set by Tensorflow layers.
  PyObject* keras_mask;

  // This stores `_tensor_shape`, a cached `TensorShape` object, and is set the
  // first time that `_EagerTensorBase`'s `shape` property is called.
  PyObject* tensor_shape;

  // We store a status object here as an optimization to avoid allocating a new
  // Status objects on different functions that operate on EagerTensor and need
  // to use a TF_Status object. However note that accesses to `status` are not
  // thread-safe.
  TF_Status* status;

  PyObject* weakreflist; /* List of weak references */

  // Per-instance attribute dictionary, to support monkey patching
  // (e.g. EagerTensor.assign when slicing variables). This dictionary is
  // created by CPython the first time an attribute is assigned, pointed to by
  // tp_dictoffset. Note that garbage collection is not enabled for
  // EagerTensors, so assigning objects to EagerTensor attributes which require
  // garbage collection is likely to cause issues.
  PyObject* dict;
} EagerTensor;
    321 
    322 namespace {
    323 
// Invokes `eager_tensor_profiler.created(created_tensor)` if a profiler has
// been registered via TFE_Py_InitEagerTensor.
// Returns true on success - successfully invoked or no profiler registered.
// Returns false if some error occurred.
bool MaybeInvokeCreatedOnEagerTensorProfiler(EagerTensor* created_tensor) {
  if (eager_tensor_profiler != nullptr) {
    // Interned method-name object; a new reference that must be DECREF'd on
    // every exit path below.
#if PY_MAJOR_VERSION < 3
    PyObject* created_method_name = PyString_InternFromString("created");
#else
    PyObject* created_method_name = PyUnicode_InternFromString("created");
#endif
    if (created_method_name == nullptr) {
      return false;
    }
    PyObject* result = PyObject_CallMethodObjArgs(
        eager_tensor_profiler, created_method_name, created_tensor, NULL);
    if (result == nullptr) {
      LOG(ERROR) << "Invoking created() on EagerTensor profiler failed";
      // While we can potentially continue because the error is related to
      // profiling, we choose to return an error because:
      //  - If profiling is used, the user likely wants to stop execution on
      //    profiling errors.
      //  - Error in profiling code might have left some state in an invalid
      //    form that can lead to an error later on. Better to fail fast.
      Py_DECREF(created_method_name);
      return false;
    }
    Py_DECREF(created_method_name);
    Py_DECREF(result);
  }
  return true;
}
    354 
    355 }  // namespace
    356 
    357 // tp_init for EagerTensor.
    358 int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
    359   self->id = get_uid();
    360   self->handle = nullptr;
    361   Py_INCREF(Py_None);
    362   self->handle_data = Py_None;
    363   Py_INCREF(Py_None);
    364   self->keras_mask = Py_None;
    365   Py_INCREF(Py_None);
    366   self->tensor_shape = Py_None;
    367   self->status = TF_NewStatus();
    368   self->dict = nullptr;
    369   self->weakreflist = nullptr;
    370   PyObject* value;
    371   PyObject* context = nullptr;
    372   PyObject* device = nullptr;
    373   PyObject* dtype = Py_None;
    374   PyObject* other_value = nullptr;
    375   const char* kwlist[] = {"value", "context",     "device",
    376                           "dtype", "other_value", nullptr};
    377   if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|OO",
    378                                    const_cast<char**>(kwlist), &value, &context,
    379                                    &device, &dtype, &other_value)) {
    380     return -1;
    381   }
    382 
    383   if (other_value != nullptr) {
    384     if (!EagerTensor_CheckExact(other_value)) {
    385       PyErr_SetString(PyExc_TypeError,
    386                       tensorflow::strings::StrCat(
    387                           "Expecting an EagerTensor for other_value, got ",
    388                           Py_TYPE(other_value)->tp_name)
    389                           .c_str());
    390 
    391       return -1;
    392     }
    393     EagerTensor* other = reinterpret_cast<EagerTensor*>(other_value);
    394     self->handle =
    395         TFE_TensorHandleCopySharingTensor(other->handle, self->status);
    396 
    397     if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
    398       return -1;
    399     }
    400 
    401     return 0;
    402   }
    403 
    404   // Extract dtype
    405   int desired_dtype = -1;
    406   if (dtype != Py_None) {
    407     if (!PyIntToDataType(dtype, &desired_dtype)) {
    408       PyErr_SetString(PyExc_TypeError,
    409                       tensorflow::strings::StrCat(
    410                           "Expecting a DataType value for dtype. Got ",
    411                           Py_TYPE(dtype)->tp_name)
    412                           .c_str());
    413       return -1;
    414     }
    415   }
    416   PyErr_Clear();
    417   tensorflow::Safe_TFE_TensorHandlePtr handle =
    418       tensorflow::make_safe(static_cast<TFE_TensorHandle*>(
    419           tensorflow::ConvertToEagerTensor(value, dtype)));
    420   if (handle == nullptr) return -1;
    421   TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get());
    422   if (desired_dtype >= 0 && desired_dtype != handle_dtype) {
    423     // Check type compatibility.
    424     if (tensorflow::IsCompatible(desired_dtype, handle_dtype)) {
    425       handle = tensorflow::make_safe(tensorflow::EagerCast(
    426           GetContext(context), handle.get(), handle_dtype,
    427           static_cast<TF_DataType>(desired_dtype), self->status));
    428       if (TF_GetCode(self->status) != TF_OK) {
    429         PyErr_SetString(
    430             PyExc_TypeError,
    431             tensorflow::strings::StrCat(
    432                 "Error while casting from DataType ",
    433                 tensorflow::DataTypeString(
    434                     static_cast<tensorflow::DataType>(handle_dtype)),
    435                 " to ",
    436                 tensorflow::DataTypeString(
    437                     static_cast<tensorflow::DataType>(desired_dtype)),
    438                 ". ", TF_Message(self->status))
    439                 .c_str());
    440         // Cleanup self->status before returning.
    441         TF_SetStatus(self->status, TF_OK, "");
    442         return -1;
    443       }
    444       handle_dtype = TFE_TensorHandleDataType(handle.get());
    445     } else {
    446       tensorflow::Safe_PyObjectPtr value_str(PyObject_Str(value));
    447       PyErr_SetString(
    448           PyExc_TypeError,
    449           tensorflow::strings::StrCat(
    450               "Cannot convert provided value to EagerTensor. Provided value: ",
    451               TFE_GetPythonString(value_str.get()), " Requested dtype: ",
    452               tensorflow::DataTypeString(
    453                   static_cast<tensorflow::DataType>(desired_dtype)))
    454               .c_str());
    455       return -1;
    456     }
    457   }
    458 
    459   // Almost all TensorFlow kernels for GPU devices keep int32 tensors in host
    460   // memory. We approximate the same behavior for eager execution - keeping
    461   // int32 tensors in host memory.
    462   //
    463   // We do so to preclude the need for callers into such kernels from having to
    464   // explicitly place the int32 tensors in host memory. For example, without
    465   // this, one needed:
    466   //
    467   // with tf.device('/gpu:0'):
    468   //   ...// code here
    469   //   with tf.device('/cpu:0'):
    470   //     shape = tf.constant(...)
    471   //   y = tf.random_uniform(shape)
    472   //
    473   // Without the CPU device block, tfe.ops.random_uniform would fail since the
    474   // kernel expects the shape in host memory.
    475   //
    476   // With this support, we simplify the code:
    477   //
    478   // with tf.device('/gpu:0'):
    479   //   y = tf.random_uniform(...)
    480   //
    481   // The approximation is not exact there are GPU kernels which do not require
    482   // host memory for int32 tensors. This will lead to a discrepancy between
    483   // eager and graph execution.
    484   // TODO(ashankar): Fix this.
    485   if (handle_dtype != TF_INT32) {
    486     // Note that this is a shallow copy and will share the underlying buffer
    487     // if copying to the same device.
    488     handle = tensorflow::make_safe(CopyToDevice(handle.get(), context, device));
    489     if (handle == nullptr) return -1;
    490   }
    491   self->handle = handle.release();
    492 
    493   if (!MaybeInvokeCreatedOnEagerTensorProfiler(self)) {
    494     return -1;
    495   }
    496 
    497   return 0;
    498 }
    499 
// tp_dealloc for EagerTensor. Releases every resource acquired in
// EagerTensor_init, in an order that keeps the object observable-safe while
// weak references are being cleared.
void EagerTensor_dealloc(EagerTensor* self) {
  // Clear weak references to self.
  // Needs to happen before any actual destruction.
  PyObject_ClearWeakRefs((PyObject*)self);

  TF_DeleteStatus(self->status);
  Py_DECREF(self->handle_data);
  Py_DECREF(self->keras_mask);
  Py_DECREF(self->tensor_shape);
  // If an attribute dictionary has been created, release it. Note that this
  // is only ever created by CPython's attribute setting methods; we don't
  // create it ourselves.
  Py_CLEAR(self->dict);
  if (self->handle != nullptr) {
    TFE_DeleteTensorHandle(self->handle);
    self->handle = nullptr;
  }
  // We have the global interpreter lock, so use this chance to perform delayed
  // refcount decrements.
  tensorflow::ClearDecrefCache();
  // Save the id before freeing the object's memory, since the tape cleanup
  // below still needs it.
  auto id = self->id;
  Py_TYPE(self)->tp_free(self);
  TFE_Py_TapeSetDeleteTrace(id);
}
    525 
    526 // Getter for `_id`.
    527 static PyObject* EagerTensor_getid(EagerTensor* self, void* closure) {
    528   return PyLong_FromLongLong(self->id);
    529 }
    530 
    531 // Getter for `_datatype_enum`.
    532 static PyObject* EagerTensor_datatype_enum(EagerTensor* self) {
    533   return PyIntFromDataType(TFE_TensorHandleDataType(self->handle));
    534 }
    535 
// Getter for `_shape_tuple`: returns the tensor's shape as a tuple of Python
// ints, or nullptr with an exception set on failure.
static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
  auto handle = self->handle;
  int n = TFE_TensorHandleNumDims(handle, self->status);
  if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
    // Cleanup self->status before returning.
    TF_SetStatus(self->status, TF_OK, "");
    return nullptr;
  }
  PyObject* shape = PyTuple_New(n);
  if (PyErr_Occurred()) return nullptr;
  for (int i = 0; i < n; ++i) {
    PyObject* dim =
        PyLong_FromLongLong(TFE_TensorHandleDim(handle, i, self->status));
    // On any failure, release `shape` and `dim`. If PyTuple_SetItem succeeded
    // it stole the reference to `dim`, so `dim` is only DECREF'd here on the
    // paths where it was not stored in the tuple.
    if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError) ||
        dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) {
      // Cleanup self->status before returning.
      TF_SetStatus(self->status, TF_OK, "");
      Py_DECREF(shape);
      if (dim != nullptr) Py_DECREF(dim);
      PyErr_SetString(PyExc_RuntimeError, "Error while creating shape");
      return nullptr;
    }
  }
  return shape;
}
    562 
// Getter for `_rank`: the number of dimensions of the tensor as a Python int.
static PyObject* EagerTensor_rank(EagerTensor* self) {
  int num_dims = TFE_TensorHandleNumDims(self->handle, self->status);
  if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
    // Cleanup self->status before returning.
    TF_SetStatus(self->status, TF_OK, "");
    return nullptr;
  }
#if PY_MAJOR_VERSION < 3
  return PyInt_FromLong(num_dims);
#else
  return PyLong_FromLong(num_dims);
#endif
}
    577 
    578 // Getter for `_num_elements`.
    579 static PyObject* EagerTensor_num_elements(EagerTensor* self) {
    580   auto handle = self->handle;
    581   int n = TFE_TensorHandleNumElements(handle, self->status);
    582   if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
    583     // Cleanup self->status before returning.
    584     TF_SetStatus(self->status, TF_OK, "");
    585     return nullptr;
    586   }
    587   return PyLong_FromLongLong(n);
    588 }
    589 
    590 static PyObject* EagerTensor_tensor_handle(EagerTensor* self, void* unused) {
    591   Py_INCREF(self->handle_data);
    592   return self->handle_data;
    593 }
    594 
    595 static int EagerTensor_settensor_handle(EagerTensor* self, PyObject* value,
    596                                         void* unused) {
    597   Py_DECREF(self->handle_data);
    598   Py_INCREF(value);
    599   self->handle_data = value;
    600   return 0;
    601 }
    602 
    603 static PyObject* EagerTensor_keras_mask(EagerTensor* self, void* unused) {
    604   Py_INCREF(self->keras_mask);
    605   return self->keras_mask;
    606 }
    607 
    608 static int EagerTensor_setkeras_mask(EagerTensor* self, PyObject* value,
    609                                      void* unused) {
    610   Py_DECREF(self->keras_mask);
    611   Py_INCREF(value);
    612   self->keras_mask = value;
    613   return 0;
    614 }
    615 
    616 static PyObject* EagerTensor_tensor_shape(EagerTensor* self, void* unused) {
    617   Py_INCREF(self->tensor_shape);
    618   return self->tensor_shape;
    619 }
    620 
    621 static int EagerTensor_settensor_shape(EagerTensor* self, PyObject* value,
    622                                        void* unused) {
    623   Py_DECREF(self->tensor_shape);
    624   Py_INCREF(value);
    625   self->tensor_shape = value;
    626   return 0;
    627 }
    628 // Function `_copy_to_device`.
    629 static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args,
    630                                             PyObject* kwds) {
    631   const char* kwlist[] = {"context", "device", nullptr};
    632   PyObject* ctx = nullptr;
    633   PyObject* dev = nullptr;
    634   if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO", const_cast<char**>(kwlist),
    635                                    &ctx, &dev) ||
    636       !ctx || !dev) {
    637     return nullptr;
    638   }
    639   auto handle = CopyToDevice(self->handle, ctx, dev);
    640   return EagerTensorFromHandle(handle);
    641 }
    642 
    643 // Function `_numpy`.
    644 // Convert an EagerTensor to a Python numpy.ndarray object.
    645 // The two may share underlying storage so changes to one may reflect in the
    646 // other.
    647 // Note that if `self` is not on CPU, we raise an Exception.
    648 static PyObject* EagerTensor_numpy(EagerTensor* self) {
    649   auto status = tensorflow::make_safe(TF_NewStatus());
    650   const tensorflow::Tensor* t =
    651       TFE_TensorHandleUnderlyingTensorInHostMemory(self->handle, status.get());
    652   if (TF_GetCode(status.get()) != TF_OK) {
    653     PyErr_SetString(PyExc_RuntimeError, TF_Message(status.get()));
    654     return nullptr;
    655   }
    656   PyObject* ret = nullptr;
    657   auto cppstatus = tensorflow::TensorToNdarray(*t, &ret);
    658   if (MaybeRaiseExceptionFromStatus(cppstatus, PyExc_RuntimeError)) {
    659     Py_XDECREF(ret);
    660     return nullptr;
    661   } else {
    662     return ret;
    663   }
    664 }
    665 
// Getter `device`: the name of the device this tensor is placed on, as a
// Python string (bytes on Python 2).
static PyObject* EagerTensor_device(EagerTensor* self) {
  const char* device = TFE_TensorHandleDeviceName(self->handle, self->status);
  if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
    // Cleanup self->status before returning.
    TF_SetStatus(self->status, TF_OK, "");
    return nullptr;
  }
#if PY_MAJOR_VERSION >= 3
  return PyUnicode_FromString(device);
#else
  return PyBytes_FromString(device);
#endif
}
    680 
// Getter `backing_device`: the name of the device holding the tensor's
// memory, as a Python string (bytes on Python 2). May differ from `device`,
// e.g. for int32 tensors kept in host memory (see EagerTensor_init).
static PyObject* EagerTensor_backing_device(EagerTensor* self) {
  const char* device =
      TFE_TensorHandleBackingDeviceName(self->handle, self->status);
  if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
    // Cleanup self->status before returning.
    TF_SetStatus(self->status, TF_OK, "");
    return nullptr;
  }
#if PY_MAJOR_VERSION >= 3
  return PyUnicode_FromString(device);
#else
  return PyBytes_FromString(device);
#endif
}
    696 
    697 static PyGetSetDef EagerTensor_getseters[] = {
    698     {const_cast<char*>("_id"), (getter)EagerTensor_getid, nullptr,
    699      const_cast<char*>("_id"), nullptr},
    700     {const_cast<char*>("device"), (getter)EagerTensor_device, nullptr,
    701      const_cast<char*>("device"), nullptr},
    702     {const_cast<char*>("backing_device"), (getter)EagerTensor_backing_device,
    703      nullptr, const_cast<char*>("backing_device"), nullptr},
    704     {const_cast<char*>("_handle_data"), (getter)EagerTensor_tensor_handle,
    705      (setter)EagerTensor_settensor_handle, const_cast<char*>("_tensor_handle"),
    706      nullptr},
    707     {const_cast<char*>("_keras_mask"), (getter)EagerTensor_keras_mask,
    708      (setter)EagerTensor_setkeras_mask, const_cast<char*>("_keras_mask"),
    709      nullptr},
    710     {const_cast<char*>("_tensor_shape"), (getter)EagerTensor_tensor_shape,
    711      (setter)EagerTensor_settensor_shape, const_cast<char*>("_tensor_shape"),
    712      nullptr},
    713     {nullptr} /* Sentinel */
    714 };
    715 
#if PY_MAJOR_VERSION < 3
// Only used for Python2 since Python3 seems to set the __dict__ correctly.
// Exposes the instance attribute dictionary (created lazily by CPython, see
// the `dict` field comment in EagerTensor) as a read-only `__dict__` member.
static PyMemberDef EagerTensor_members[] = {
    {const_cast<char*>("__dict__"), T_OBJECT, offsetof(EagerTensor, dict),
     READONLY},
    {nullptr},
};
#endif
    724 
// Method table for EagerTensor. All entries are private helpers consumed by
// the Python-side _EagerTensorBase wrapper.
static PyMethodDef EagerTensor_methods[] = {
    {"_numpy", (PyCFunction)EagerTensor_numpy, METH_NOARGS,
     PyDoc_STR("_numpy")},
    {"_datatype_enum", (PyCFunction)EagerTensor_datatype_enum, METH_NOARGS,
     PyDoc_STR("_datatype_enum")},
    {"_shape_tuple", (PyCFunction)EagerTensor_shape_tuple, METH_NOARGS,
     PyDoc_STR("_shape_tuple")},
    {"_rank", (PyCFunction)EagerTensor_rank, METH_NOARGS, PyDoc_STR("_rank")},
    {"_copy_to_device", (PyCFunction)EagerTensor_copy_to_device,
     METH_VARARGS | METH_KEYWORDS, PyDoc_STR("_copy_to_device")},
    {"_num_elements", (PyCFunction)EagerTensor_num_elements, METH_NOARGS,
     PyDoc_STR("_num_elements")},
    {nullptr, nullptr},
};
    739 
    740 // Note that here we are trying to dynamically create a new class as a subclass
    741 // of a "HEAPTYPE" class that is itself created in python code and passed in at
    742 // runtime. This is fairly atypical and undocumented.
    743 //
    744 // We use the following strategy for this. Unfortunately, we have to use
    745 // different approaches for python2.x vs python3.x
    746 // For python2.x, we create the class as a static type and set its tp_base to
    747 // the passed in type. Unfortunately setting tp_flags to include
    748 // Py_TPFLAGS_HEAPTYPE does not work by itself since it needs some more
    749 // initialization of the underlying PyHeapTypeObject and not doing that leads to
    750 // some random crashes especially during garbage collection.
    751 // python3.x explicitly disables a static subclass of a HEAPTYPE base class.
    752 // However it provides a new function, PyType_FromSpecWithBases, to create
    753 // types dynamically.
    754 
// Type object for EagerTensor. This is set by TFE_Py_InitEagerTensor and
// remains nullptr until that initializer has been called from Python.
PyTypeObject* EagerTensorType = nullptr;
    757 
    758 #if PY_MAJOR_VERSION >= 3
// Slot table consumed by PyType_FromSpecWithBases in TFE_Py_InitEagerTensor
// (python3 path). Only dealloc/methods/getset/init are overridden; all other
// slots are inherited from the dynamically supplied base class.
static PyType_Slot EagerTensor_Type_slots[] = {
    {Py_tp_dealloc, reinterpret_cast<void*>(EagerTensor_dealloc)},
    {Py_tp_methods, reinterpret_cast<void*>(EagerTensor_methods)},
    {Py_tp_getset, reinterpret_cast<void*>(EagerTensor_getseters)},
    {Py_tp_init, reinterpret_cast<void*>(EagerTensor_init)},
    {0, nullptr},  // Sentinel terminating the slot list.
};
    766 #else
    767 // TODO(agarwal): support active_trace.
// Statically initialized type object for EagerTensor (python2 path). Its
// tp_base is patched in at runtime by TFE_Py_InitEagerTensor before
// PyType_Ready is called. Slot order is positional and must match the
// PyTypeObject layout exactly.
static PyTypeObject _EagerTensorType = {
    // clang-format off
    PyVarObject_HEAD_INIT(nullptr, 0)
    // clang-format on
    "EagerTensor",                      /* tp_name */
    sizeof(EagerTensor),                /* tp_basicsize */
    0,                                  /* tp_itemsize */
    (destructor)EagerTensor_dealloc,    /* tp_dealloc */
    nullptr,                            /* tp_print */
    nullptr,                            /* tp_getattr */
    nullptr,                            /* tp_setattr */
    nullptr,                            /* tp_compare */
    nullptr,                            /* tp_repr */
    nullptr,                            /* tp_as_number */
    nullptr,                            /* tp_as_sequence */
    nullptr,                            /* tp_as_mapping */
    nullptr,                            /* tp_hash */
    nullptr,                            /* tp_call */
    nullptr,                            /* tp_str */
    nullptr,                            /* tp_getattro */
    nullptr,                            /* tp_setattro */
    nullptr,                            /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT,                 /* tp_flags */
    nullptr,                            /* tp_doc */
    nullptr,                            /* tp_traverse */
    nullptr,                            /* tp_clear */
    nullptr,                            /* tp_richcompare */
    offsetof(EagerTensor, weakreflist), /* tp_weaklistoffset */
    nullptr,                            /* tp_iter */
    nullptr,                            /* tp_iternext */
    EagerTensor_methods,                /* tp_methods */
    EagerTensor_members,                /* tp_members */
    EagerTensor_getseters,              /* tp_getset */
    nullptr,                            /* tp_base (set at runtime) */
    nullptr,                            /* tp_dict */
    nullptr,                            /* tp_descr_get */
    nullptr,                            /* tp_descr_set */
    offsetof(EagerTensor, dict),        /* tp_dictoffset */
    (initproc)EagerTensor_init,         /* tp_init */
    nullptr,                            /* tp_alloc */
    nullptr,                            /* tp_new */
};
    810 
    811 #endif
    812 
    813 }  // extern "C"
    814 
    815 bool EagerTensor_CheckExact(const PyObject* o) {
    816   return Py_TYPE(o) == EagerTensorType;
    817 }
    818 
    819 TFE_TensorHandle* EagerTensor_Handle(const PyObject* o) {
    820   return reinterpret_cast<const EagerTensor*>(o)->handle;
    821 }
    822 
    823 PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) {
    824   if (handle == nullptr) {
    825     return nullptr;
    826   }
    827   EagerTensor* t = reinterpret_cast<EagerTensor*>(
    828       EagerTensorType->tp_new(EagerTensorType, Py_None, Py_None));
    829   if (t != nullptr) {
    830     t->id = get_uid();
    831     Py_INCREF(Py_None);
    832     t->handle_data = Py_None;
    833     Py_INCREF(Py_None);
    834     t->keras_mask = Py_None;
    835     Py_INCREF(Py_None);
    836     t->tensor_shape = Py_None;
    837     t->handle = handle;
    838     t->status = TF_NewStatus();
    839     t->weakreflist = nullptr;
    840 
    841     if (!MaybeInvokeCreatedOnEagerTensorProfiler(t)) {
    842       return nullptr;
    843     }
    844   }
    845   return reinterpret_cast<PyObject*>(t);
    846 }
    847 
    848 tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor) {
    849   DCHECK(EagerTensor_CheckExact(tensor));
    850   return reinterpret_cast<const EagerTensor*>(tensor)->id;
    851 }
    852 
    853 tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor) {
    854   DCHECK(EagerTensor_CheckExact(tensor));
    855   return static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(
    856       reinterpret_cast<const EagerTensor*>(tensor)->handle));
    857 }
    858 
    859 tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor) {
    860   DCHECK(EagerTensor_CheckExact(tensor));
    861   const EagerTensor* as_c_eager_tensor =
    862       reinterpret_cast<const EagerTensor*>(tensor);
    863   tensorflow::int64 result = TFE_TensorHandleNumElements(
    864       as_c_eager_tensor->handle, as_c_eager_tensor->status);
    865 
    866   if (MaybeRaiseExceptionFromTFStatus(as_c_eager_tensor->status,
    867                                       PyExc_ValueError)) {
    868     // Cleanup status before returning.
    869     TF_SetStatus(as_c_eager_tensor->status, TF_OK, "");
    870     return -1;
    871   }
    872 
    873   return result;
    874 }
    875 
// Creates the EagerTensor type as a subclass of the python-defined
// `base_class` and stores it in the global EagerTensorType. Returns the new
// type object, or nullptr with a python exception set on failure. See the
// comment block above EagerTensorType for why the python2 and python3 paths
// differ.
PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
  if (!PyType_Check(base_class)) {
    PyErr_SetString(
        PyExc_TypeError,
        tensorflow::strings::StrCat(
            "Expecting a class definition for `base_class` passed to ",
            "TFE_InitEagerTensor. Got ", Py_TYPE(base_class)->tp_name)
            .c_str());
    return nullptr;
  }
  // Note that we allocated kMaxEagerTensorParentSize bytes of unused space in
  // EagerTensor to allow for the space usage of the base class.
  PyTypeObject* base_class_type = reinterpret_cast<PyTypeObject*>(base_class);
  if (base_class_type->tp_basicsize > kMaxEagerTensorParentSize) {
    PyErr_SetString(
        PyExc_TypeError,
        tensorflow::strings::StrCat(
            "Unable to create subclass EagerTensor from base class ",
            Py_TYPE(base_class)->tp_name,
            ". Need its size to be <= ", kMaxEagerTensorParentSize)
            .c_str());
    return nullptr;
  }
  // Variable-length base classes (tp_itemsize != 0) cannot be embedded in the
  // fixed-size EagerTensor struct.
  if (base_class_type->tp_itemsize != 0) {
    PyErr_SetString(
        PyExc_TypeError,
        tensorflow::strings::StrCat(
            "Unable to create subclass EagerTensor from base class ",
            Py_TYPE(base_class)->tp_name,
            " which supports variable length instances.")
            .c_str());
    return nullptr;
  }
  Py_INCREF(base_class);
#if PY_MAJOR_VERSION >= 3
  // PyTuple_SET_ITEM steals the reference taken by Py_INCREF above.
  // NOTE(review): `bases` does not appear to be released on any path below —
  // presumably a deliberate leak for a once-per-process call; confirm.
  PyObject* bases = PyTuple_New(1);
  PyTuple_SET_ITEM(bases, 0, base_class);

  // Derive the module name for the new type from the base class, falling
  // back to "__builtin__" if it cannot be read as bytes or UTF-8.
  tensorflow::Safe_PyObjectPtr base_class_module(
      PyObject_GetAttrString(base_class, "__module__"));
  const char* module = nullptr;
  if (PyErr_Occurred()) {
    PyErr_Clear();
    module = "__builtin__";
  } else {
    module = PyBytes_AsString(base_class_module.get());
    if (module == nullptr) {
      PyErr_Clear();
      module = PyUnicode_AsUTF8(base_class_module.get());
      if (module == nullptr) {
        PyErr_Clear();
        module = "__builtin__";
      }
    }
  }

  // NOTE: The c_str from this string needs to outlast the function, hence is
  // static.
  // NOTE(review): as a function-local static this is initialized only on the
  // first call; a different `module` on a later call is ignored — confirm
  // this function is only ever called once per process.
  static tensorflow::string fully_qualified_name =
      tensorflow::strings::StrCat(module, ".EagerTensor");

  static PyType_Spec EagerTensor_Type_spec = {
      fully_qualified_name.c_str(), sizeof(EagerTensor), 0,
      Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE, EagerTensor_Type_slots};

  EagerTensorType = reinterpret_cast<PyTypeObject*>(
      PyType_FromSpecWithBases(&EagerTensor_Type_spec, bases));
  if (PyErr_Occurred()) {
    return nullptr;
  }
  if (EagerTensorType == nullptr) {
    PyErr_SetString(PyExc_RuntimeError, "Error while creating EagerTensorType");
    return nullptr;
  }
  // PyType_FromSpec has no way to express tp_dictoffset, so patch it in after
  // the fact to expose the instance __dict__ stored in EagerTensor::dict.
  EagerTensorType->tp_dictoffset = offsetof(EagerTensor, dict);
#else
  // python2: patch the static type's base and finalize it with PyType_Ready.
  _EagerTensorType.tp_base = base_class_type;

  if (PyType_Ready(&_EagerTensorType) < 0) {
    if (PyErr_Occurred()) return nullptr;
    PyErr_SetString(PyExc_RuntimeError,
                    "Error while creating EagerTensor type.");
    return nullptr;
  }
  EagerTensorType = &_EagerTensorType;
  Py_INCREF(EagerTensorType);
#endif
  return reinterpret_cast<PyObject*>(EagerTensorType);
}
    965 
    966 PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler) {
    967   Py_XDECREF(eager_tensor_profiler);
    968 
    969   if (profiler == Py_None) {
    970     eager_tensor_profiler = nullptr;
    971   } else {
    972     eager_tensor_profiler = profiler;
    973     Py_INCREF(eager_tensor_profiler);
    974   }
    975   Py_RETURN_NONE;
    976 }
    977 
// Builds a rank-1 int32 EagerTensor whose i-th element is the size of
// dimension `slice_dim` of the i-th tensor in `tensors` (a list or tuple of
// EagerTensors). Returns nullptr with a python exception set on any error.
PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim) {
  if (!PyList_Check(tensors) && !PyTuple_Check(tensors)) {
    PyErr_SetString(PyExc_TypeError,
                    tensorflow::strings::StrCat(
                        "tensors argument must be a list or a tuple. Got \"",
                        Py_TYPE(tensors)->tp_name, "\"")
                        .c_str());
    return nullptr;
  }
  if (slice_dim < 0) {
    PyErr_SetString(
        PyExc_ValueError,
        tensorflow::strings::StrCat("Slice dimension must be non-negative. "
                                    "Got ",
                                    slice_dim)
            .c_str());
    return nullptr;
  }

  Py_ssize_t num_tensors = PySequence_Fast_GET_SIZE(tensors);
  int64_t num_tensors_int = static_cast<int64_t>(num_tensors);
  // Allocate the int32 output tensor up front: one element per input tensor
  // (len is 4 bytes per int32 element).
  auto tensor = tensorflow::make_safe(TF_AllocateTensor(
      TF_INT32, &num_tensors_int, /*num_dims=*/1, /*len=*/4 * num_tensors_int));
  int32_t* data = reinterpret_cast<int32_t*>(TF_TensorData(tensor.get()));
  auto status = tensorflow::make_safe(TF_NewStatus());
  for (Py_ssize_t i = 0; i < num_tensors; ++i) {
    // Borrowed reference; no decref needed.
    PyObject* tensor_obj = PySequence_Fast_GET_ITEM(tensors, i);
    if (!EagerTensor_CheckExact(tensor_obj)) {
      PyErr_SetString(PyExc_TypeError,
                      tensorflow::strings::StrCat(
                          "Expected a list of EagerTensors but "
                          "element ",
                          i, " has type \"", Py_TYPE(tensor_obj)->tp_name, "\"")
                          .c_str());
      return nullptr;
    }

    EagerTensor* t = reinterpret_cast<EagerTensor*>(tensor_obj);
    TFE_TensorHandle* handle = t->handle;
    int num_dims = TFE_TensorHandleNumDims(handle, status.get());
    if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) {
      return nullptr;
    }
    if (slice_dim >= num_dims) {
      PyErr_SetString(
          PyExc_IndexError,
          tensorflow::strings::StrCat("Slice dimension (", slice_dim,
                                      ") must be smaller than rank of all "
                                      "tensors, but tensor at index ",
                                      i, " has rank ", num_dims)
              .c_str());
      return nullptr;
    }
    int64_t dim = TFE_TensorHandleDim(handle, slice_dim, status.get());
    if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) {
      return nullptr;
    }
    // NOTE(review): int64 dim is narrowed into an int32 slot; dimensions
    // >= 2^31 would be silently truncated — confirm this is acceptable for
    // all callers.
    data[i] = dim;
  }

  // Wrap the filled buffer in a handle; the safe_ptr still owns `tensor`.
  TFE_TensorHandle* handle = TFE_NewTensorHandle(tensor.get(), status.get());
  if (TF_GetCode(status.get()) != TF_OK) {
    PyErr_SetString(
        PyExc_RuntimeError,
        tensorflow::strings::StrCat("Failed to construct new tensor handle: ",
                                    TF_Message(status.get()))
            .c_str());
    return nullptr;
  }

  return EagerTensorFromHandle(handle);
}
   1050 
// Returns the on-device shape of `tensor` (an EagerTensor) as a python tuple
// of ints, which may differ from the logical shape for some device layouts.
// Returns nullptr with a python exception set on error.
PyObject* TFE_Py_TensorShapeOnDevice(PyObject* tensor) {
  if (!EagerTensor_CheckExact(tensor)) {
    PyErr_SetString(
        PyExc_TypeError,
        tensorflow::strings::StrCat("Expected an EagerTensors but got type \"",
                                    Py_TYPE(tensor)->tp_name, "\"")
            .c_str());
    return nullptr;
  }
  TFE_TensorHandle* handle = EagerTensor_Handle(tensor);

  auto status = tensorflow::make_safe(TF_NewStatus());
  TFE_TensorDebugInfo* debug_info =
      TFE_TensorHandleTensorDebugInfo(handle, status.get());
  if (TF_GetCode(status.get()) != TF_OK) {
    PyErr_SetString(
        PyExc_RuntimeError,
        tensorflow::strings::StrCat("Error retrieving tensor's device shape: ",
                                    TF_Message(status.get()))
            .c_str());
    return nullptr;
  }

  int rank = TFE_TensorDebugInfoOnDeviceNumDims(debug_info);
  PyObject* shape = PyTuple_New(rank);
  for (int i = 0; i < rank; ++i) {
    tensorflow::int64 dim_size = TFE_TensorDebugInfoOnDeviceDim(debug_info, i);
    // PyTuple_SET_ITEM steals the reference created by PyLong_FromLongLong.
    PyTuple_SET_ITEM(shape, i, PyLong_FromLongLong(dim_size));
  }
  // debug_info must be released explicitly; it is not owned by the handle.
  TFE_DeleteTensorDebugInfo(debug_info);

  return shape;
}
   1084