Home | History | Annotate | Download | only in eager
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <cstring>
     17 #include <thread>
     18 
     19 #include "tensorflow/python/eager/pywrap_tfe.h"
     20 
     21 #include "absl/strings/str_cat.h"
     22 #include "absl/types/variant.h"
     23 #include "tensorflow/c/c_api.h"
     24 #include "tensorflow/c/c_api_internal.h"
     25 #include "tensorflow/c/eager/c_api_internal.h"
     26 #include "tensorflow/c/eager/tape.h"
     27 #include "tensorflow/core/lib/core/errors.h"
     28 #include "tensorflow/core/lib/gtl/cleanup.h"
     29 #include "tensorflow/core/lib/gtl/compactptrset.h"
     30 #include "tensorflow/core/lib/gtl/flatmap.h"
     31 #include "tensorflow/core/lib/gtl/flatset.h"
     32 #include "tensorflow/core/lib/strings/strcat.h"
     33 #include "tensorflow/core/lib/strings/stringprintf.h"
     34 #include "tensorflow/core/platform/mutex.h"
     35 #include "tensorflow/core/platform/protobuf.h"
     36 #include "tensorflow/core/platform/types.h"
     37 #include "tensorflow/python/eager/pywrap_tensor.h"
     38 #include "tensorflow/python/lib/core/safe_ptr.h"
     39 #include "tensorflow/python/util/util.h"
     40 
     41 using tensorflow::string;
     42 using tensorflow::strings::Printf;
     43 
     44 namespace {
     45 
     46 struct InputInfo {
     47   InputInfo(int i, bool is_list) : i(i), is_list(is_list) {}
     48 
     49   int i;
     50   bool is_list = false;
     51 };
     52 
     53 // Takes in output gradients, returns input gradients.
     54 typedef std::function<PyObject*(PyObject*)> PyBackwardFunction;
     55 
     56 using AttrToInputsMap =
     57     tensorflow::gtl::FlatMap<string,
     58                              tensorflow::gtl::InlinedVector<InputInfo, 4>>;
     59 
     60 tensorflow::mutex all_attr_to_input_maps_lock(tensorflow::LINKER_INITIALIZED);
     61 tensorflow::gtl::FlatMap<string, AttrToInputsMap*>* GetAllAttrToInputsMaps() {
     62   static auto* all_attr_to_input_maps =
     63       new tensorflow::gtl::FlatMap<string, AttrToInputsMap*>;
     64   return all_attr_to_input_maps;
     65 }
     66 
     67 AttrToInputsMap* GetAttrToInputsMap(const tensorflow::OpDef& op_def) {
     68   tensorflow::mutex_lock l(all_attr_to_input_maps_lock);
     69   auto* all_attr_to_input_maps = GetAllAttrToInputsMaps();
     70 
     71   auto* output =
     72       tensorflow::gtl::FindPtrOrNull(*all_attr_to_input_maps, op_def.name());
     73   if (output != nullptr) {
     74     return output;
     75   }
     76 
     77   std::unique_ptr<AttrToInputsMap> m(new AttrToInputsMap);
     78 
     79   // Store a list of InputIndex -> List of corresponding inputs.
     80   for (int i = 0; i < op_def.input_arg_size(); i++) {
     81     if (!op_def.input_arg(i).type_attr().empty()) {
     82       auto it = m->find(op_def.input_arg(i).type_attr());
     83       if (it == m->end()) {
     84         it = m->insert({op_def.input_arg(i).type_attr(), {}}).first;
     85       }
     86       it->second.emplace_back(i, !op_def.input_arg(i).number_attr().empty());
     87     }
     88   }
     89 
     90   auto* retval = m.get();
     91   (*all_attr_to_input_maps)[op_def.name()] = m.release();
     92 
     93   return retval;
     94 }
     95 
// State carried while executing an op on the fast path: the eager context and
// device, the op's OpDef, callback flags, the Python objects describing the
// call, and dtype information shared between inputs.
// NOTE(review): members have no default initializers; presumably the
// fast-path code fills every field before use — confirm.
struct FastPathOpExecInfo {
  TFE_Context* ctx;
  const char* device_name;
  // The op def of the main op being executed.
  const tensorflow::OpDef* op_def;

  // NOTE(review): flags selecting which callbacks to run; meaning inferred
  // from the names — confirm against the fast-path execution code.
  bool run_callbacks;
  bool run_post_exec_callbacks;
  bool run_gradient_callback;

  // The op name of the main op being executed.
  PyObject* name;
  // The op type name of the main op being executed.
  PyObject* op_name;
  PyObject* callbacks;

  // All the args passed into the FastPathOpExecInfo.
  PyObject* args;

  // DTypes can come from another input that has the same attr. So build that
  // map.
  const AttrToInputsMap* attr_to_inputs_map;
  // NOTE(review): presumably caches resolved dtypes keyed by type attr name.
  tensorflow::gtl::FlatMap<string, tensorflow::DataType> cached_dtypes;
};
    120 
    121 #define PARSE_VALUE(fn_name, type, check_fn, parse_fn)                       \
    122   bool fn_name(const string& key, PyObject* py_value, TF_Status* status,     \
    123                type* value) {                                                \
    124     if (check_fn(py_value)) {                                                \
    125       *value = static_cast<type>(parse_fn(py_value));                        \
    126       return true;                                                           \
    127     } else {                                                                 \
    128       TF_SetStatus(status, TF_INVALID_ARGUMENT,                              \
    129                    tensorflow::strings::StrCat(                              \
    130                        "Expecting " #type " value for attr ", key, ", got ", \
    131                        py_value->ob_type->tp_name)                           \
    132                        .c_str());                                            \
    133       return false;                                                          \
    134     }                                                                        \
    135   }
    136 
    137 #if PY_MAJOR_VERSION >= 3
    138 PARSE_VALUE(ParseIntValue, int, PyLong_Check, PyLong_AsLong)
    139 PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLong)
    140 #else
    141 PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong)
    142 #endif
    143 PARSE_VALUE(ParseFloatValue, float, PyFloat_Check, PyFloat_AsDouble)
    144 #undef PARSE_VALUE
    145 
    146 #if PY_MAJOR_VERSION < 3
    147 bool ParseInt64Value(const string& key, PyObject* py_value, TF_Status* status,
    148                      int64_t* value) {
    149   if (PyInt_Check(py_value)) {
    150     *value = static_cast<int64_t>(PyInt_AsLong(py_value));
    151     return true;
    152   } else if (PyLong_Check(py_value)) {
    153     *value = static_cast<int64_t>(PyLong_AsLong(py_value));
    154     return true;
    155   }
    156   TF_SetStatus(
    157       status, TF_INVALID_ARGUMENT,
    158       tensorflow::strings::StrCat("Expecting int or long value for attr ", key,
    159                                   ", got ", py_value->ob_type->tp_name)
    160           .c_str());
    161   return false;
    162 }
    163 #endif
    164 
    165 Py_ssize_t TensorShapeNumDims(PyObject* value) {
    166   const auto size = PySequence_Size(value);
    167   if (size == -1) {
    168     // TensorShape.__len__ raises an error in the scenario where the shape is an
    169     // unknown, which needs to be cleared.
    170     // TODO(nareshmodi): ensure that this is actually a TensorShape.
    171     PyErr_Clear();
    172   }
    173   return size;
    174 }
    175 
    176 bool IsInteger(PyObject* py_value) {
    177 #if PY_MAJOR_VERSION >= 3
    178   return PyLong_Check(py_value);
    179 #else
    180   return PyInt_Check(py_value);
    181 #endif
    182 }
    183 
    184 // This function considers a Dimension._value of None to be valid, and sets the
    185 // value to be -1 in that case.
    186 bool ParseDimensionValue(const string& key, PyObject* py_value,
    187                          TF_Status* status, int64_t* value) {
    188   if (IsInteger(py_value)) {
    189     return ParseInt64Value(key, py_value, status, value);
    190   }
    191 
    192   tensorflow::Safe_PyObjectPtr dimension_value(
    193       PyObject_GetAttrString(py_value, "_value"));
    194   if (dimension_value == nullptr) {
    195     TF_SetStatus(
    196         status, TF_INVALID_ARGUMENT,
    197         tensorflow::strings::StrCat("Expecting a Dimension for attr ", key,
    198                                     ", got ", py_value->ob_type->tp_name)
    199             .c_str());
    200     return false;
    201   }
    202 
    203   if (dimension_value.get() == Py_None) {
    204     *value = -1;
    205     return true;
    206   }
    207 
    208   return ParseInt64Value(key, dimension_value.get(), status, value);
    209 }
    210 
    211 bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status,
    212                       tensorflow::StringPiece* value) {
    213   if (PyBytes_Check(py_value)) {
    214     Py_ssize_t size = 0;
    215     char* buf = nullptr;
    216     if (PyBytes_AsStringAndSize(py_value, &buf, &size) < 0) return false;
    217     *value = tensorflow::StringPiece(buf, size);
    218     return true;
    219   }
    220 #if PY_MAJOR_VERSION >= 3
    221   if (PyUnicode_Check(py_value)) {
    222     Py_ssize_t size = 0;
    223     const char* buf = PyUnicode_AsUTF8AndSize(py_value, &size);
    224     if (buf == nullptr) return false;
    225     *value = tensorflow::StringPiece(buf, size);
    226     return true;
    227   }
    228 #endif
    229   TF_SetStatus(
    230       status, TF_INVALID_ARGUMENT,
    231       tensorflow::strings::StrCat("Expecting a string value for attr ", key,
    232                                   ", got ", py_value->ob_type->tp_name)
    233           .c_str());
    234   return false;
    235 }
    236 
    237 bool ParseBoolValue(const string& key, PyObject* py_value, TF_Status* status,
    238                     unsigned char* value) {
    239   *value = PyObject_IsTrue(py_value);
    240   return true;
    241 }
    242 
    243 // The passed in py_value is expected to be an object of the python type
    244 // dtypes.DType or an int.
    245 bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status,
    246                     int* value) {
    247   if (IsInteger(py_value)) {
    248     return ParseIntValue(key, py_value, status, value);
    249   }
    250 
    251   tensorflow::Safe_PyObjectPtr py_type_enum(
    252       PyObject_GetAttrString(py_value, "_type_enum"));
    253   if (py_type_enum == nullptr) {
    254     PyErr_Clear();
    255     TF_SetStatus(
    256         status, TF_INVALID_ARGUMENT,
    257         tensorflow::strings::StrCat("Expecting a DType.dtype for attr ", key,
    258                                     ", got ", py_value->ob_type->tp_name)
    259             .c_str());
    260     return false;
    261   }
    262 
    263   return ParseIntValue(key, py_type_enum.get(), status, value);
    264 }
    265 
    266 bool SetOpAttrList(
    267     TFE_Context* ctx, TFE_Op* op, const char* key, PyObject* py_list,
    268     TF_AttrType type,
    269     tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
    270     TF_Status* status) {
    271   if (!PySequence_Check(py_list)) {
    272     TF_SetStatus(
    273         status, TF_INVALID_ARGUMENT,
    274         tensorflow::strings::StrCat("Expecting sequence value for attr ", key,
    275                                     ", got ", py_list->ob_type->tp_name)
    276             .c_str());
    277     return false;
    278   }
    279   const int num_values = PySequence_Size(py_list);
    280   if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = num_values;
    281 
    282 #define PARSE_LIST(c_type, parse_fn)                                      \
    283   std::unique_ptr<c_type[]> values(new c_type[num_values]);               \
    284   for (int i = 0; i < num_values; ++i) {                                  \
    285     tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));   \
    286     if (!parse_fn(key, py_value.get(), status, &values[i])) return false; \
    287   }
    288 
    289   if (type == TF_ATTR_STRING) {
    290     std::unique_ptr<const void*[]> values(new const void*[num_values]);
    291     std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
    292     for (int i = 0; i < num_values; ++i) {
    293       tensorflow::StringPiece value;
    294       tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
    295       if (!ParseStringValue(key, py_value.get(), status, &value)) return false;
    296       values[i] = value.data();
    297       lengths[i] = value.size();
    298     }
    299     TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
    300   } else if (type == TF_ATTR_INT) {
    301     PARSE_LIST(int64_t, ParseInt64Value);
    302     TFE_OpSetAttrIntList(op, key, values.get(), num_values);
    303   } else if (type == TF_ATTR_FLOAT) {
    304     PARSE_LIST(float, ParseFloatValue);
    305     TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
    306   } else if (type == TF_ATTR_BOOL) {
    307     PARSE_LIST(unsigned char, ParseBoolValue);
    308     TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
    309   } else if (type == TF_ATTR_TYPE) {
    310     PARSE_LIST(int, ParseTypeValue);
    311     TFE_OpSetAttrTypeList(op, key,
    312                           reinterpret_cast<const TF_DataType*>(values.get()),
    313                           num_values);
    314   } else if (type == TF_ATTR_SHAPE) {
    315     // Make one pass through the input counting the total number of
    316     // dims across all the input lists.
    317     int total_dims = 0;
    318     for (int i = 0; i < num_values; ++i) {
    319       tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
    320       if (py_value.get() != Py_None) {
    321         if (!PySequence_Check(py_value.get())) {
    322           TF_SetStatus(
    323               status, TF_INVALID_ARGUMENT,
    324               tensorflow::strings::StrCat(
    325                   "Expecting None or sequence value for element", i,
    326                   " of attr ", key, ", got ", py_value->ob_type->tp_name)
    327                   .c_str());
    328           return false;
    329         }
    330         const auto size = TensorShapeNumDims(py_value.get());
    331         if (size >= 0) {
    332           total_dims += size;
    333         }
    334       }
    335     }
    336     // Allocate a buffer that can fit all of the dims together.
    337     std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
    338     // Copy the input dims into the buffer and set dims to point to
    339     // the start of each list's dims.
    340     std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
    341     std::unique_ptr<int[]> num_dims(new int[num_values]);
    342     int64_t* offset = buffer.get();
    343     for (int i = 0; i < num_values; ++i) {
    344       tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
    345       if (py_value.get() == Py_None) {
    346         dims[i] = nullptr;
    347         num_dims[i] = -1;
    348       } else {
    349         const auto size = TensorShapeNumDims(py_value.get());
    350         if (size == -1) {
    351           dims[i] = nullptr;
    352           num_dims[i] = -1;
    353           continue;
    354         }
    355         dims[i] = offset;
    356         num_dims[i] = size;
    357         for (int j = 0; j < size; ++j) {
    358           tensorflow::Safe_PyObjectPtr inner_py_value(
    359               PySequence_ITEM(py_value.get(), j));
    360           if (inner_py_value.get() == Py_None) {
    361             *offset = -1;
    362           } else if (!ParseDimensionValue(key, inner_py_value.get(), status,
    363                                           offset)) {
    364             return false;
    365           }
    366           ++offset;
    367         }
    368       }
    369     }
    370     TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
    371                            status);
    372     if (TF_GetCode(status) != TF_OK) return false;
    373   } else if (type == TF_ATTR_FUNC) {
    374     std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
    375     for (int i = 0; i < num_values; ++i) {
    376       tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
    377       // Allow:
    378       // (1) String function name, OR
    379       // (2) A Python object with a .name attribute
    380       //     (A crude test for being a
    381       //     tensorflow.python.framework.function._DefinedFunction)
    382       //     (which is what the various "defun" or "Defun" decorators do).
    383       // And in the future also allow an object that can encapsulate
    384       // the function name and its attribute values.
    385       tensorflow::StringPiece func_name;
    386       if (!ParseStringValue(key, py_value.get(), status, &func_name)) {
    387         PyObject* name_attr = PyObject_GetAttrString(py_value.get(), "name");
    388         if (name_attr == nullptr ||
    389             !ParseStringValue(key, name_attr, status, &func_name)) {
    390           TF_SetStatus(
    391               status, TF_INVALID_ARGUMENT,
    392               tensorflow::strings::StrCat(
    393                   "unable to set function value attribute from a ",
    394                   py_value.get()->ob_type->tp_name,
    395                   " object. If you think this is an error, please file an "
    396                   "issue at "
    397                   "https://github.com/tensorflow/tensorflow/issues/new")
    398                   .c_str());
    399           return false;
    400         }
    401       }
    402       funcs[i] = TFE_NewOp(ctx, func_name.data(), status);
    403       if (TF_GetCode(status) != TF_OK) return false;
    404     }
    405     TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
    406     if (TF_GetCode(status) != TF_OK) return false;
    407   } else {
    408     TF_SetStatus(status, TF_UNIMPLEMENTED,
    409                  tensorflow::strings::StrCat("Attr ", key,
    410                                              " has unhandled list type ", type)
    411                      .c_str());
    412     return false;
    413   }
    414 #undef PARSE_LIST
    415   return true;
    416 }
    417 
    418 TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
    419                 TF_Status* status) {
    420   TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
    421   for (const auto& attr : func.attr()) {
    422     if (TF_GetCode(status) != TF_OK) return nullptr;
    423     SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
    424     if (TF_GetCode(status) != TF_OK) return nullptr;
    425   }
    426   return func_op;
    427 }
    428 
    429 void SetOpAttrListDefault(
    430     TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
    431     const char* key, TF_AttrType type,
    432     tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
    433     TF_Status* status) {
    434   if (type == TF_ATTR_STRING) {
    435     int num_values = attr.default_value().list().s_size();
    436     std::unique_ptr<const void*[]> values(new const void*[num_values]);
    437     std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
    438     (*attr_list_sizes)[key] = num_values;
    439     for (int i = 0; i < num_values; i++) {
    440       const string& v = attr.default_value().list().s(i);
    441       values[i] = v.data();
    442       lengths[i] = v.size();
    443     }
    444     TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
    445   } else if (type == TF_ATTR_INT) {
    446     int num_values = attr.default_value().list().i_size();
    447     std::unique_ptr<int64_t[]> values(new int64_t[num_values]);
    448     (*attr_list_sizes)[key] = num_values;
    449     for (int i = 0; i < num_values; i++) {
    450       values[i] = attr.default_value().list().i(i);
    451     }
    452     TFE_OpSetAttrIntList(op, key, values.get(), num_values);
    453   } else if (type == TF_ATTR_FLOAT) {
    454     int num_values = attr.default_value().list().f_size();
    455     std::unique_ptr<float[]> values(new float[num_values]);
    456     (*attr_list_sizes)[key] = num_values;
    457     for (int i = 0; i < num_values; i++) {
    458       values[i] = attr.default_value().list().f(i);
    459     }
    460     TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
    461   } else if (type == TF_ATTR_BOOL) {
    462     int num_values = attr.default_value().list().b_size();
    463     std::unique_ptr<unsigned char[]> values(new unsigned char[num_values]);
    464     (*attr_list_sizes)[key] = num_values;
    465     for (int i = 0; i < num_values; i++) {
    466       values[i] = attr.default_value().list().b(i);
    467     }
    468     TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
    469   } else if (type == TF_ATTR_TYPE) {
    470     int num_values = attr.default_value().list().type_size();
    471     std::unique_ptr<int[]> values(new int[num_values]);
    472     (*attr_list_sizes)[key] = num_values;
    473     for (int i = 0; i < num_values; i++) {
    474       values[i] = attr.default_value().list().type(i);
    475     }
    476     TFE_OpSetAttrTypeList(op, key,
    477                           reinterpret_cast<const TF_DataType*>(values.get()),
    478                           attr.default_value().list().type_size());
    479   } else if (type == TF_ATTR_SHAPE) {
    480     int num_values = attr.default_value().list().shape_size();
    481     (*attr_list_sizes)[key] = num_values;
    482     int total_dims = 0;
    483     for (int i = 0; i < num_values; ++i) {
    484       if (!attr.default_value().list().shape(i).unknown_rank()) {
    485         total_dims += attr.default_value().list().shape(i).dim_size();
    486       }
    487     }
    488     // Allocate a buffer that can fit all of the dims together.
    489     std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
    490     // Copy the input dims into the buffer and set dims to point to
    491     // the start of each list's dims.
    492     std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
    493     std::unique_ptr<int[]> num_dims(new int[num_values]);
    494     int64_t* offset = buffer.get();
    495     for (int i = 0; i < num_values; ++i) {
    496       const auto& shape = attr.default_value().list().shape(i);
    497       if (shape.unknown_rank()) {
    498         dims[i] = nullptr;
    499         num_dims[i] = -1;
    500       } else {
    501         for (int j = 0; j < shape.dim_size(); j++) {
    502           *offset = shape.dim(j).size();
    503           ++offset;
    504         }
    505       }
    506     }
    507     TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
    508                            status);
    509   } else if (type == TF_ATTR_FUNC) {
    510     int num_values = attr.default_value().list().func_size();
    511     (*attr_list_sizes)[key] = num_values;
    512     std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
    513     for (int i = 0; i < num_values; i++) {
    514       funcs[i] = GetFunc(ctx, attr.default_value().list().func(i), status);
    515     }
    516     TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
    517   } else {
    518     TF_SetStatus(status, TF_UNIMPLEMENTED,
    519                  "Lists of tensors are not yet implemented for default valued "
    520                  "attributes for an operation.");
    521   }
    522 }
    523 
    524 bool SetOpAttrScalar(
    525     TFE_Context* ctx, TFE_Op* op, const char* key, PyObject* py_value,
    526     TF_AttrType type,
    527     tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
    528     TF_Status* status) {
    529   if (type == TF_ATTR_STRING) {
    530     tensorflow::StringPiece value;
    531     if (!ParseStringValue(key, py_value, status, &value)) return false;
    532     TFE_OpSetAttrString(op, key, value.data(), value.size());
    533   } else if (type == TF_ATTR_INT) {
    534     int64_t value;
    535     if (!ParseInt64Value(key, py_value, status, &value)) return false;
    536     TFE_OpSetAttrInt(op, key, value);
    537     // attr_list_sizes is set for all int attributes (since at this point we are
    538     // not aware if that attribute might be used to calculate the size of an
    539     // output list or not).
    540     if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = value;
    541   } else if (type == TF_ATTR_FLOAT) {
    542     float value;
    543     if (!ParseFloatValue(key, py_value, status, &value)) return false;
    544     TFE_OpSetAttrFloat(op, key, value);
    545   } else if (type == TF_ATTR_BOOL) {
    546     unsigned char value;
    547     if (!ParseBoolValue(key, py_value, status, &value)) return false;
    548     TFE_OpSetAttrBool(op, key, value);
    549   } else if (type == TF_ATTR_TYPE) {
    550     int value;
    551     if (!ParseTypeValue(key, py_value, status, &value)) return false;
    552     TFE_OpSetAttrType(op, key, static_cast<TF_DataType>(value));
    553   } else if (type == TF_ATTR_SHAPE) {
    554     if (py_value == Py_None) {
    555       TFE_OpSetAttrShape(op, key, nullptr, -1, status);
    556     } else {
    557       if (!PySequence_Check(py_value)) {
    558         TF_SetStatus(status, TF_INVALID_ARGUMENT,
    559                      tensorflow::strings::StrCat(
    560                          "Expecting None or sequence value for attr", key,
    561                          ", got ", py_value->ob_type->tp_name)
    562                          .c_str());
    563         return false;
    564       }
    565       const auto num_dims = TensorShapeNumDims(py_value);
    566       if (num_dims == -1) {
    567         TFE_OpSetAttrShape(op, key, nullptr, -1, status);
    568         return true;
    569       }
    570       std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
    571       for (int i = 0; i < num_dims; ++i) {
    572         tensorflow::Safe_PyObjectPtr inner_py_value(
    573             PySequence_ITEM(py_value, i));
    574         if (inner_py_value.get() == Py_None) {
    575           dims[i] = -1;
    576         } else if (!ParseDimensionValue(key, inner_py_value.get(), status,
    577                                         &dims[i])) {
    578           return false;
    579         }
    580       }
    581       TFE_OpSetAttrShape(op, key, dims.get(), num_dims, status);
    582     }
    583     if (TF_GetCode(status) != TF_OK) return false;
    584   } else if (type == TF_ATTR_FUNC) {
    585     // Allow:
    586     // (1) String function name, OR
    587     // (2) A Python object with a .name attribute
    588     //     (A crude test for being a
    589     //     tensorflow.python.framework.function._DefinedFunction)
    590     //     (which is what the various "defun" or "Defun" decorators do).
    591     // And in the future also allow an object that can encapsulate
    592     // the function name and its attribute values.
    593     tensorflow::StringPiece func_name;
    594     if (!ParseStringValue(key, py_value, status, &func_name)) {
    595       PyObject* name_attr = PyObject_GetAttrString(py_value, "name");
    596       if (name_attr == nullptr ||
    597           !ParseStringValue(key, name_attr, status, &func_name)) {
    598         TF_SetStatus(
    599             status, TF_INVALID_ARGUMENT,
    600             tensorflow::strings::StrCat(
    601                 "unable to set function value attribute from a ",
    602                 py_value->ob_type->tp_name,
    603                 " object. If you think this is an error, please file an issue "
    604                 "at https://github.com/tensorflow/tensorflow/issues/new")
    605                 .c_str());
    606         return false;
    607       }
    608     }
    609     TF_SetStatus(status, TF_OK, "");
    610     TFE_OpSetAttrFunctionName(op, key, func_name.data(), func_name.size());
    611   } else {
    612     TF_SetStatus(
    613         status, TF_UNIMPLEMENTED,
    614         tensorflow::strings::StrCat("Attr ", key, " has unhandled type ", type)
    615             .c_str());
    616     return false;
    617   }
    618   return true;
    619 }
    620 
    621 void SetOpAttrScalarDefault(
    622     TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value,
    623     const char* attr_name,
    624     tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
    625     TF_Status* status) {
    626   SetOpAttrValueScalar(ctx, op, default_value, attr_name, status);
    627   if (default_value.value_case() == tensorflow::AttrValue::kI) {
    628     (*attr_list_sizes)[attr_name] = default_value.i();
    629   }
    630 }
    631 
    632 // start_index is the index at which the Tuple/List attrs will start getting
    633 // processed.
    634 void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index,
    635                 TF_Status* out_status) {
    636   if (attrs == Py_None) return;
    637   Py_ssize_t len = PyTuple_GET_SIZE(attrs) - start_index;
    638   if ((len & 1) != 0) {
    639     TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
    640                  "Expecting attrs tuple to have even length.");
    641     return;
    642   }
    643   // Parse attrs
    644   for (Py_ssize_t i = 0; i < len; i += 2) {
    645     PyObject* py_key = PyTuple_GET_ITEM(attrs, start_index + i);
    646     PyObject* py_value = PyTuple_GET_ITEM(attrs, start_index + i + 1);
    647 #if PY_MAJOR_VERSION >= 3
    648     const char* key = PyBytes_Check(py_key) ? PyBytes_AsString(py_key)
    649                                             : PyUnicode_AsUTF8(py_key);
    650 #else
    651     const char* key = PyBytes_AsString(py_key);
    652 #endif
    653     unsigned char is_list = 0;
    654     const TF_AttrType type = TFE_OpGetAttrType(op, key, &is_list, out_status);
    655     if (TF_GetCode(out_status) != TF_OK) return;
    656     if (is_list != 0) {
    657       if (!SetOpAttrList(ctx, op, key, py_value, type, nullptr, out_status))
    658         return;
    659     } else {
    660       if (!SetOpAttrScalar(ctx, op, key, py_value, type, nullptr, out_status))
    661         return;
    662     }
    663   }
    664 }
    665 
    666 // This function will set the op attrs required. If an attr has the value of
    667 // None, then it will read the AttrDef to get the default value and set that
    668 // instead. Any failure in this function will simply fall back to the slow
    669 // path.
    670 void SetOpAttrWithDefaults(
    671     TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
    672     const char* attr_name, PyObject* attr_value,
    673     tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
    674     TF_Status* status) {
    675   unsigned char is_list = 0;
    676   const TF_AttrType type = TFE_OpGetAttrType(op, attr_name, &is_list, status);
    677   if (TF_GetCode(status) != TF_OK) return;
    678   if (attr_value == Py_None) {
    679     if (is_list != 0) {
    680       SetOpAttrListDefault(ctx, op, attr, attr_name, type, attr_list_sizes,
    681                            status);
    682     } else {
    683       SetOpAttrScalarDefault(ctx, op, attr.default_value(), attr_name,
    684                              attr_list_sizes, status);
    685     }
    686   } else {
    687     if (is_list != 0) {
    688       SetOpAttrList(ctx, op, attr_name, attr_value, type, attr_list_sizes,
    689                     status);
    690     } else {
    691       SetOpAttrScalar(ctx, op, attr_name, attr_value, type, attr_list_sizes,
    692                       status);
    693     }
    694   }
    695 }
    696 
// Python subclass of Exception that is created on not ok Status.
// Set by TFE_Py_RegisterExceptionClass, which stores a strong reference (or
// nullptr when registration failed / never happened). Reads and writes are
// guarded by exception_class_mutex.
tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr;

// Python subclass of Exception that is created to signal fallback.
// Set by TFE_Py_RegisterFallbackExceptionClass (strong reference or nullptr).
// NOTE(review): unlike exception_class this is not mutex-guarded; presumably
// callers rely on the GIL — confirm.
PyObject* fallback_exception_class = nullptr;

// Python function that returns input gradients given output gradients.
// Set by TFE_Py_RegisterGradientFunction (strong reference or nullptr).
PyObject* gradient_function = nullptr;

// Type registered via TFE_Py_RegisterResourceVariableType (strong reference),
// or nullptr if none has been registered.
PyTypeObject* resource_variable_type = nullptr;

// Monotonic counter handed out by get_uid(); only touched under _uid_mutex.
tensorflow::mutex _uid_mutex(tensorflow::LINKER_INITIALIZED);
tensorflow::int64 _uid GUARDED_BY(_uid_mutex) = 0;
    711 
    712 }  // namespace
    713 
    714 void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
    715                     const char* op_name, TFE_InputTensorHandles* inputs,
    716                     PyObject* attrs, TFE_OutputTensorHandles* outputs,
    717                     TF_Status* out_status) {
    718   TFE_Op* op = TFE_NewOp(ctx, op_name, out_status);
    719   if (TF_GetCode(out_status) != TF_OK) return;
    720   TFE_OpSetDevice(op, device_name, out_status);
    721   if (TF_GetCode(out_status) == TF_OK) {
    722     for (int i = 0; i < inputs->size() && TF_GetCode(out_status) == TF_OK;
    723          ++i) {
    724       TFE_OpAddInput(op, inputs->at(i), out_status);
    725     }
    726   }
    727   if (TF_GetCode(out_status) == TF_OK) {
    728     SetOpAttrs(ctx, op, attrs, 0, out_status);
    729   }
    730   Py_BEGIN_ALLOW_THREADS;
    731   if (TF_GetCode(out_status) == TF_OK) {
    732     int num_outputs = outputs->size();
    733     TFE_Execute(op, outputs->data(), &num_outputs, out_status);
    734     outputs->resize(num_outputs);
    735   }
    736   if (TF_GetCode(out_status) != TF_OK) {
    737     TF_SetStatus(out_status, TF_GetCode(out_status),
    738                  tensorflow::strings::StrCat(TF_Message(out_status),
    739                                              " [Op:", op_name, "]")
    740                      .c_str());
    741   }
    742   TFE_DeleteOp(op);
    743   Py_END_ALLOW_THREADS;
    744 }
    745 
    746 PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
    747   tensorflow::mutex_lock l(exception_class_mutex);
    748   if (exception_class != nullptr) {
    749     Py_DECREF(exception_class);
    750   }
    751   if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
    752     exception_class = nullptr;
    753     PyErr_SetString(PyExc_TypeError,
    754                     "TFE_Py_RegisterExceptionClass: "
    755                     "Registered class should be subclass of Exception.");
    756     return nullptr;
    757   }
    758 
    759   Py_INCREF(e);
    760   exception_class = e;
    761   Py_RETURN_NONE;
    762 }
    763 
    764 PyObject* TFE_Py_RegisterResourceVariableType(PyObject* e) {
    765   if (!PyType_Check(e)) {
    766     PyErr_SetString(
    767         PyExc_TypeError,
    768         "TFE_Py_RegisterResourceVariableType: Need to register a type.");
    769     return nullptr;
    770   }
    771 
    772   if (resource_variable_type != nullptr) {
    773     Py_DECREF(resource_variable_type);
    774   }
    775 
    776   Py_INCREF(e);
    777   resource_variable_type = reinterpret_cast<PyTypeObject*>(e);
    778   Py_RETURN_NONE;
    779 }
    780 
    781 PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) {
    782   if (fallback_exception_class != nullptr) {
    783     Py_DECREF(fallback_exception_class);
    784   }
    785   if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
    786     fallback_exception_class = nullptr;
    787     PyErr_SetString(PyExc_TypeError,
    788                     "TFE_Py_RegisterFallbackExceptionClass: "
    789                     "Registered class should be subclass of Exception.");
    790     return nullptr;
    791   } else {
    792     Py_INCREF(e);
    793     fallback_exception_class = e;
    794     Py_RETURN_NONE;
    795   }
    796 }
    797 
    798 PyObject* TFE_Py_RegisterGradientFunction(PyObject* e) {
    799   if (gradient_function != nullptr) {
    800     Py_DECREF(gradient_function);
    801   }
    802   if (!PyCallable_Check(e)) {
    803     gradient_function = nullptr;
    804     PyErr_SetString(PyExc_TypeError,
    805                     "TFE_Py_RegisterBackwardFunctionGetter: "
    806                     "Registered object should be function.");
    807     return nullptr;
    808   } else {
    809     Py_INCREF(e);
    810     gradient_function = e;
    811     Py_RETURN_NONE;
    812   }
    813 }
    814 
    815 void RaiseFallbackException(const char* message) {
    816   if (fallback_exception_class != nullptr) {
    817     PyErr_SetString(fallback_exception_class, message);
    818     return;
    819   }
    820 
    821   PyErr_SetString(
    822       PyExc_RuntimeError,
    823       tensorflow::strings::StrCat(
    824           "Fallback exception type not set, attempting to fallback due to ",
    825           message)
    826           .data());
    827 }
    828 
    829 int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) {
    830   if (TF_GetCode(status) == TF_OK) return 0;
    831   const char* msg = TF_Message(status);
    832   if (exception == nullptr) {
    833     tensorflow::mutex_lock l(exception_class_mutex);
    834     if (exception_class != nullptr) {
    835       tensorflow::Safe_PyObjectPtr val(
    836           Py_BuildValue("si", msg, TF_GetCode(status)));
    837       if (PyErr_Occurred()) {
    838         // NOTE: This hides the actual error (i.e. the reason `status` was not
    839         // TF_OK), but there is nothing we can do at this point since we can't
    840         // generate a reasonable error from the status.
    841         // Consider adding a message explaining this.
    842         return -1;
    843       }
    844       PyErr_SetObject(exception_class, val.get());
    845       return -1;
    846     } else {
    847       exception = PyExc_RuntimeError;
    848     }
    849   }
    850   // May be update already set exception.
    851   PyErr_SetString(exception, msg);
    852   return -1;
    853 }
    854 
    855 int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
    856                                   PyObject* exception) {
    857   if (status.ok()) return 0;
    858   const char* msg = status.error_message().c_str();
    859   if (exception == nullptr) {
    860     tensorflow::mutex_lock l(exception_class_mutex);
    861     if (exception_class != nullptr) {
    862       tensorflow::Safe_PyObjectPtr val(Py_BuildValue("si", msg, status.code()));
    863       PyErr_SetObject(exception_class, val.get());
    864       return -1;
    865     } else {
    866       exception = PyExc_RuntimeError;
    867     }
    868   }
    869   // May be update already set exception.
    870   PyErr_SetString(exception, msg);
    871   return -1;
    872 }
    873 
    874 const char* TFE_GetPythonString(PyObject* o) {
    875 #if PY_MAJOR_VERSION >= 3
    876   if (PyBytes_Check(o)) {
    877     return PyBytes_AsString(o);
    878   } else {
    879     return PyUnicode_AsUTF8(o);
    880   }
    881 #else
    882   return PyBytes_AsString(o);
    883 #endif
    884 }
    885 
    886 int64_t get_uid() {
    887   tensorflow::mutex_lock l(_uid_mutex);
    888   return _uid++;
    889 }
    890 
    891 PyObject* TFE_Py_UID() { return PyLong_FromLongLong(get_uid()); }
    892 
    893 void TFE_DeleteContextCapsule(PyObject* context) {
    894   TFE_Context* ctx =
    895       reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(context, nullptr));
    896   TFE_DeleteContext(ctx);
    897 }
    898 
    899 static tensorflow::int64 MakeInt(PyObject* integer) {
    900 #if PY_MAJOR_VERSION >= 3
    901   return PyLong_AsLong(integer);
    902 #else
    903   return PyInt_AsLong(integer);
    904 #endif
    905 }
    906 
    907 static tensorflow::int64 FastTensorId(PyObject* tensor) {
    908   if (EagerTensor_CheckExact(tensor)) {
    909     return PyEagerTensor_ID(tensor);
    910   }
    911   PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
    912   if (id_field == nullptr) {
    913     return -1;
    914   }
    915   tensorflow::int64 id = MakeInt(id_field);
    916   Py_DECREF(id_field);
    917   return id;
    918 }
    919 
    920 static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
    921   if (EagerTensor_CheckExact(tensor)) {
    922     return PyEagerTensor_Dtype(tensor);
    923   }
    924   PyObject* dtype_field = PyObject_GetAttrString(tensor, "dtype");
    925   if (dtype_field == nullptr) {
    926     return tensorflow::DT_INVALID;
    927   }
    928   PyObject* enum_field = PyObject_GetAttrString(dtype_field, "_type_enum");
    929   Py_DECREF(dtype_field);
    930   if (dtype_field == nullptr) {
    931     return tensorflow::DT_INVALID;
    932   }
    933   tensorflow::int64 id = MakeInt(enum_field);
    934   Py_DECREF(enum_field);
    935   return static_cast<tensorflow::DataType>(id);
    936 }
    937 
    938 class PyTapeTensor {
    939  public:
    940   PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
    941                const tensorflow::TensorShape& shape)
    942       : id_(id), dtype_(dtype), shape_(shape) {}
    943   PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
    944                PyObject* shape)
    945       : id_(id), dtype_(dtype), shape_(shape) {
    946     Py_INCREF(absl::get<1>(shape_));
    947   }
    948   PyTapeTensor(const PyTapeTensor& other) {
    949     id_ = other.id_;
    950     dtype_ = other.dtype_;
    951     shape_ = other.shape_;
    952     if (shape_.index() == 1) {
    953       Py_INCREF(absl::get<1>(shape_));
    954     }
    955   }
    956 
    957   ~PyTapeTensor() {
    958     if (shape_.index() == 1) {
    959       Py_DECREF(absl::get<1>(shape_));
    960     }
    961   }
    962   PyObject* GetShape() const;
    963   PyObject* GetDType() const { return PyLong_FromLong(dtype_); }
    964   tensorflow::int64 GetID() const { return id_; }
    965 
    966  private:
    967   tensorflow::int64 id_;
    968   tensorflow::DataType dtype_;
    969   absl::variant<tensorflow::TensorShape, PyObject*> shape_;
    970 };
    971 
    972 class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
    973                                                   PyTapeTensor> {
    974  public:
    975   explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
    976     Py_INCREF(py_vspace_);
    977   }
    978 
    979   tensorflow::Status Initialize() {
    980     num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
    981     if (num_elements_ == nullptr) {
    982       return tensorflow::errors::InvalidArgument("invalid vspace");
    983     }
    984     aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
    985     if (aggregate_fn_ == nullptr) {
    986       return tensorflow::errors::InvalidArgument("invalid vspace");
    987     }
    988     zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn");
    989     if (zeros_fn_ == nullptr) {
    990       return tensorflow::errors::InvalidArgument("invalid vspace");
    991     }
    992     ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn");
    993     if (ones_fn_ == nullptr) {
    994       return tensorflow::errors::InvalidArgument("invalid vspace");
    995     }
    996     graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn");
    997     if (graph_shape_fn_ == nullptr) {
    998       return tensorflow::errors::InvalidArgument("invalid vspace");
    999     }
   1000     return tensorflow::Status::OK();
   1001   }
   1002 
   1003   ~PyVSpace() override {
   1004     Py_XDECREF(num_elements_);
   1005     Py_XDECREF(aggregate_fn_);
   1006     Py_XDECREF(zeros_fn_);
   1007     Py_XDECREF(ones_fn_);
   1008     Py_XDECREF(graph_shape_fn_);
   1009 
   1010     Py_DECREF(py_vspace_);
   1011   }
   1012 
   1013   tensorflow::int64 NumElements(PyObject* tensor) const final {
   1014     if (EagerTensor_CheckExact(tensor)) {
   1015       return PyEagerTensor_NumElements(tensor);
   1016     }
   1017     PyObject* arglist =
   1018         Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
   1019     PyObject* result = PyEval_CallObject(num_elements_, arglist);
   1020     Py_DECREF(arglist);
   1021     if (result == nullptr) {
   1022       // The caller detects whether a python exception has been raised.
   1023       return -1;
   1024     }
   1025     tensorflow::int64 r = MakeInt(result);
   1026     Py_DECREF(result);
   1027     return r;
   1028   }
   1029 
   1030   PyObject* AggregateGradients(
   1031       tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
   1032     PyObject* list = PyList_New(gradient_tensors.size());
   1033     for (int i = 0; i < gradient_tensors.size(); ++i) {
   1034       // Note: stealing a reference to the gradient tensors.
   1035       CHECK(gradient_tensors[i] != nullptr);
   1036       CHECK(gradient_tensors[i] != Py_None);
   1037       PyList_SET_ITEM(list, i,
   1038                       reinterpret_cast<PyObject*>(gradient_tensors[i]));
   1039     }
   1040     PyObject* arglist = Py_BuildValue("(O)", list);
   1041     CHECK(arglist != nullptr);
   1042     PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
   1043     Py_DECREF(arglist);
   1044     Py_DECREF(list);
   1045     return result;
   1046   }
   1047 
   1048   void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
   1049 
   1050   PyObject* Zeros(const PyTapeTensor& tensor) const final {
   1051     if (PyErr_Occurred()) {
   1052       return nullptr;
   1053     }
   1054     PyObject* py_shape = tensor.GetShape();
   1055     if (PyErr_Occurred()) {
   1056       return nullptr;
   1057     }
   1058     PyObject* py_dtype = tensor.GetDType();
   1059     if (PyErr_Occurred()) {
   1060       Py_DECREF(py_shape);
   1061       return nullptr;
   1062     }
   1063     PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
   1064     PyObject* result = PyEval_CallObject(zeros_fn_, arg_list);
   1065     Py_DECREF(arg_list);
   1066     Py_DECREF(py_dtype);
   1067     Py_DECREF(py_shape);
   1068     return reinterpret_cast<PyObject*>(result);
   1069   }
   1070 
   1071   PyObject* Ones(const PyTapeTensor& tensor) const final {
   1072     if (PyErr_Occurred()) {
   1073       return nullptr;
   1074     }
   1075     PyObject* py_shape = tensor.GetShape();
   1076     PyObject* py_dtype = tensor.GetDType();
   1077     PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
   1078     PyObject* result = PyEval_CallObject(ones_fn_, arg_list);
   1079     Py_DECREF(arg_list);
   1080     Py_DECREF(py_dtype);
   1081     Py_DECREF(py_shape);
   1082     return result;
   1083   }
   1084 
   1085   PyObject* GraphShape(PyObject* tensor) const {
   1086     PyObject* arg_list = Py_BuildValue("(O)", tensor);
   1087     PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list);
   1088     Py_DECREF(arg_list);
   1089     return result;
   1090   }
   1091 
   1092   tensorflow::Status CallBackwardFunction(
   1093       PyBackwardFunction* backward_function,
   1094       tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
   1095       std::vector<PyObject*>* result) const final {
   1096     PyObject* grads = PyTuple_New(output_gradients.size());
   1097     for (int i = 0; i < output_gradients.size(); ++i) {
   1098       if (output_gradients[i] == nullptr) {
   1099         Py_INCREF(Py_None);
   1100         PyTuple_SET_ITEM(grads, i, Py_None);
   1101       } else {
   1102         PyTuple_SET_ITEM(grads, i,
   1103                          reinterpret_cast<PyObject*>(output_gradients[i]));
   1104       }
   1105     }
   1106     PyObject* py_result = (*backward_function)(grads);
   1107     Py_DECREF(grads);
   1108     if (py_result == nullptr) {
   1109       return tensorflow::errors::Internal("gradient function threw exceptions");
   1110     }
   1111     result->clear();
   1112     PyObject* seq =
   1113         PySequence_Fast(py_result, "expected a sequence of gradients");
   1114     if (seq == nullptr) {
   1115       return tensorflow::errors::InvalidArgument(
   1116           "gradient function did not return a list");
   1117     }
   1118     int len = PySequence_Fast_GET_SIZE(seq);
   1119     VLOG(1) << "Gradient length is " << len;
   1120     result->reserve(len);
   1121     for (int i = 0; i < len; ++i) {
   1122       PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
   1123       if (item == Py_None) {
   1124         result->push_back(nullptr);
   1125       } else {
   1126         Py_INCREF(item);
   1127         result->push_back(item);
   1128       }
   1129     }
   1130     Py_DECREF(seq);
   1131     Py_DECREF(py_result);
   1132     return tensorflow::Status::OK();
   1133   }
   1134 
   1135   void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
   1136 
   1137  private:
   1138   PyObject* py_vspace_;
   1139 
   1140   PyObject* num_elements_;
   1141   PyObject* aggregate_fn_;
   1142   PyObject* zeros_fn_;
   1143   PyObject* ones_fn_;
   1144   PyObject* graph_shape_fn_;
   1145 };
   1146 PyVSpace* py_vspace = nullptr;
   1147 
   1148 PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
   1149   if (py_vspace != nullptr) {
   1150     delete py_vspace;
   1151   }
   1152 
   1153   py_vspace = new PyVSpace(e);
   1154   auto status = py_vspace->Initialize();
   1155   if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
   1156     delete py_vspace;
   1157     return nullptr;
   1158   }
   1159 
   1160   Py_RETURN_NONE;
   1161 }
   1162 
   1163 PyObject* PyTapeTensor::GetShape() const {
   1164   if (shape_.index() == 0) {
   1165     auto& shape = absl::get<0>(shape_);
   1166     PyObject* py_shape = PyTuple_New(shape.dims());
   1167     for (int i = 0; i < shape.dims(); ++i) {
   1168       PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
   1169     }
   1170 
   1171     return py_shape;
   1172   }
   1173 
   1174   return py_vspace->GraphShape(absl::get<1>(shape_));
   1175 }
   1176 
   1177 class GradientTape
   1178     : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
   1179                                              PyTapeTensor> {
   1180  public:
   1181   explicit GradientTape(bool persistent, bool watch_accessed_variables)
   1182       : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
   1183                                         PyTapeTensor>(persistent),
   1184         watch_accessed_variables_(watch_accessed_variables) {}
   1185 
   1186   virtual ~GradientTape() {
   1187     for (const IdAndVariable& v : watched_variables_) {
   1188       Py_DECREF(v.variable);
   1189     }
   1190   }
   1191 
   1192   void VariableAccessed(PyObject* v) {
   1193     if (watch_accessed_variables_) {
   1194       WatchVariable(v);
   1195     }
   1196   }
   1197 
   1198   void WatchVariable(PyObject* v) {
   1199     tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
   1200     if (handle == nullptr) {
   1201       return;
   1202     }
   1203     tensorflow::int64 id = FastTensorId(handle.get());
   1204 
   1205     if (!PyErr_Occurred()) {
   1206       this->Watch(id);
   1207     }
   1208 
   1209     tensorflow::mutex_lock l(watched_variables_mu_);
   1210     auto insert_result = watched_variables_.emplace(id, v);
   1211 
   1212     if (insert_result.second) {
   1213       // Only increment the reference count if we aren't already watching this
   1214       // variable.
   1215       Py_INCREF(v);
   1216     }
   1217   }
   1218 
   1219   PyObject* GetVariablesAsPyTuple() {
   1220     tensorflow::mutex_lock l(watched_variables_mu_);
   1221     PyObject* result = PyTuple_New(watched_variables_.size());
   1222     Py_ssize_t pos = 0;
   1223     for (const IdAndVariable& id_and_variable : watched_variables_) {
   1224       PyTuple_SET_ITEM(result, pos++, id_and_variable.variable);
   1225       Py_INCREF(id_and_variable.variable);
   1226     }
   1227     return result;
   1228   }
   1229 
   1230  private:
   1231   // We store an IdAndVariable in the map since the map needs to be locked
   1232   // during insert, but should not call back into python during insert to avoid
   1233   // deadlocking with the GIL.
   1234   struct IdAndVariable {
   1235     tensorflow::int64 id;
   1236     PyObject* variable;
   1237 
   1238     IdAndVariable(tensorflow::int64 id, PyObject* variable)
   1239         : id(id), variable(variable) {}
   1240   };
   1241   struct CompareById {
   1242     bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) const {
   1243       return lhs.id < rhs.id;
   1244     }
   1245   };
   1246 
   1247   bool watch_accessed_variables_;
   1248   tensorflow::mutex watched_variables_mu_;
   1249   std::set<IdAndVariable, CompareById> watched_variables_
   1250       GUARDED_BY(watched_variables_mu_);
   1251 };
   1252 
   1253 typedef struct {
   1254   PyObject_HEAD
   1255       /* Type-specific fields go here. */
   1256       GradientTape* tape;
   1257 } TFE_Py_Tape;
   1258 
   1259 static void TFE_Py_Tape_Delete(PyObject* tape) {
   1260   delete reinterpret_cast<TFE_Py_Tape*>(tape)->tape;
   1261   Py_TYPE(tape)->tp_free(tape);
   1262 }
   1263 
// Python type object for TFE_Py_Tape. Only deallocation is customized
// (TFE_Py_Tape_Delete). tp_new is filled in by TFE_Py_TapeSetNew before
// PyType_Ready is called. The initializer is positional, so the entries must
// follow PyTypeObject's field order; fields not listed are zero-initialized.
static PyTypeObject TFE_Py_Tape_Type = {
    PyVarObject_HEAD_INIT(nullptr, 0) "tfe.Tape", /* tp_name */
    sizeof(TFE_Py_Tape),                          /* tp_basicsize */
    0,                                            /* tp_itemsize */
    &TFE_Py_Tape_Delete,                          /* tp_dealloc */
    nullptr,                                      /* tp_print */
    nullptr,                                      /* tp_getattr */
    nullptr,                                      /* tp_setattr */
    nullptr,                                      /* tp_reserved */
    nullptr,                                      /* tp_repr */
    nullptr,                                      /* tp_as_number */
    nullptr,                                      /* tp_as_sequence */
    nullptr,                                      /* tp_as_mapping */
    nullptr,                                      /* tp_hash  */
    nullptr,                                      /* tp_call */
    nullptr,                                      /* tp_str */
    nullptr,                                      /* tp_getattro */
    nullptr,                                      /* tp_setattro */
    nullptr,                                      /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT,                           /* tp_flags */
    "TFE_Py_Tape objects",                        /* tp_doc */
};
   1286 
   1287 // Note: in the current design no mutex is needed here because of the python
   1288 // GIL, which is always held when any TFE_Py_* methods are called. We should
   1289 // revisit this if/when decide to not hold the GIL while manipulating the tape
   1290 // stack.
   1291 tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() {
   1292   thread_local tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* tape_set{
   1293       nullptr};
   1294   if (tape_set == nullptr) {
   1295     tape_set = new tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>;
   1296   }
   1297   return tape_set;
   1298 }
   1299 
   1300 // A safe copy of the current tapeset. Does not get affected by other python
   1301 // threads changing the set of active tapes.
   1302 class SafeTapeSet {
   1303  public:
   1304   SafeTapeSet() : tape_set_(*GetTapeSet()) {
   1305     for (auto* tape : tape_set_) {
   1306       Py_INCREF(tape);
   1307     }
   1308   }
   1309 
   1310   ~SafeTapeSet() {
   1311     for (auto* tape : tape_set_) {
   1312       Py_DECREF(tape);
   1313     }
   1314   }
   1315 
   1316   tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>::const_iterator begin() {
   1317     return tape_set_.begin();
   1318   }
   1319 
   1320   tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>::const_iterator end() {
   1321     return tape_set_.end();
   1322   }
   1323 
   1324  private:
   1325   tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*> tape_set_;
   1326 };
   1327 
   1328 bool* ThreadTapeIsStopped() {
   1329   thread_local bool thread_tape_is_stopped{false};
   1330   return &thread_tape_is_stopped;
   1331 }
   1332 
   1333 void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; }
   1334 
   1335 void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; }
   1336 
   1337 PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
   1338                             PyObject* watch_accessed_variables) {
   1339   TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
   1340   if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
   1341   TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
   1342   tape->tape = new GradientTape(persistent == Py_True,
   1343                                 watch_accessed_variables == Py_True);
   1344   Py_INCREF(tape);
   1345   GetTapeSet()->insert(reinterpret_cast<TFE_Py_Tape*>(tape));
   1346   return reinterpret_cast<PyObject*>(tape);
   1347 }
   1348 
   1349 void TFE_Py_TapeSetAdd(PyObject* tape) {
   1350   Py_INCREF(tape);
   1351   if (!GetTapeSet()->insert(reinterpret_cast<TFE_Py_Tape*>(tape)).second) {
   1352     // Already exists in the tape set.
   1353     Py_DECREF(tape);
   1354   }
   1355 }
   1356 
   1357 PyObject* TFE_Py_TapeSetIsEmpty() {
   1358   if (*ThreadTapeIsStopped() || GetTapeSet()->empty()) {
   1359     Py_RETURN_TRUE;
   1360   }
   1361   Py_RETURN_FALSE;
   1362 }
   1363 
   1364 void TFE_Py_TapeSetRemove(PyObject* tape) {
   1365   auto* stack = GetTapeSet();
   1366   stack->erase(reinterpret_cast<TFE_Py_Tape*>(tape));
   1367   // We kept a reference to the tape in the set to ensure it wouldn't get
   1368   // deleted under us; cleaning it up here.
   1369   Py_DECREF(tape);
   1370 }
   1371 
   1372 static std::vector<tensorflow::int64> MakeIntList(PyObject* list) {
   1373   if (list == Py_None) {
   1374     return {};
   1375   }
   1376   PyObject* seq = PySequence_Fast(list, "expected a sequence");
   1377   if (seq == nullptr) {
   1378     return {};
   1379   }
   1380   int len = PySequence_Size(list);
   1381   std::vector<tensorflow::int64> tensor_ids;
   1382   tensor_ids.reserve(len);
   1383   for (int i = 0; i < len; ++i) {
   1384     PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
   1385 #if PY_MAJOR_VERSION >= 3
   1386     if (PyLong_Check(item)) {
   1387 #else
   1388     if (PyLong_Check(item) || PyInt_Check(item)) {
   1389 #endif
   1390       tensorflow::int64 id = MakeInt(item);
   1391       tensor_ids.push_back(id);
   1392     } else {
   1393       tensor_ids.push_back(-1);
   1394     }
   1395   }
   1396   Py_DECREF(seq);
   1397   return tensor_ids;
   1398 }
   1399 
   1400 PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
   1401   if (tensors == Py_None) {
   1402     Py_RETURN_FALSE;
   1403   }
   1404   if (*ThreadTapeIsStopped()) {
   1405     Py_RETURN_FALSE;
   1406   }
   1407   auto* tape_set_ptr = GetTapeSet();
   1408   if (tape_set_ptr->empty()) {
   1409     Py_RETURN_FALSE;
   1410   }
   1411   PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
   1412   if (seq == nullptr) {
   1413     return nullptr;
   1414   }
   1415   int len = PySequence_Fast_GET_SIZE(seq);
   1416   // TODO(apassos) consider not building a list and changing the API to check
   1417   // each tensor individually.
   1418   std::vector<tensorflow::int64> tensor_ids;
   1419   std::vector<tensorflow::DataType> dtypes;
   1420   tensor_ids.reserve(len);
   1421   dtypes.reserve(len);
   1422   for (int i = 0; i < len; ++i) {
   1423     PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
   1424     tensor_ids.push_back(FastTensorId(item));
   1425     dtypes.push_back(FastTensorDtype(item));
   1426   }
   1427   Py_DECREF(seq);
   1428   auto tape_set = *tape_set_ptr;
   1429   for (TFE_Py_Tape* tape : tape_set) {
   1430     if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
   1431       Py_RETURN_TRUE;
   1432     }
   1433   }
   1434   Py_RETURN_FALSE;
   1435 }
   1436 
   1437 void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
   1438   if (*ThreadTapeIsStopped()) {
   1439     return;
   1440   }
   1441   tensorflow::int64 tensor_id = FastTensorId(tensor);
   1442   if (PyErr_Occurred()) {
   1443     return;
   1444   }
   1445   reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
   1446 }
   1447 
   1448 bool ListContainsNone(PyObject* list) {
   1449   if (list == Py_None) return true;
   1450   tensorflow::Safe_PyObjectPtr seq(
   1451       PySequence_Fast(list, "expected a sequence"));
   1452   if (seq == nullptr) {
   1453     return false;
   1454   }
   1455 
   1456   int len = PySequence_Size(list);
   1457   for (int i = 0; i < len; ++i) {
   1458     PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i);
   1459     if (item == Py_None) return true;
   1460   }
   1461 
   1462   return false;
   1463 }
   1464 
   1465 static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
   1466   if (EagerTensor_CheckExact(tensor)) {
   1467     TFE_TensorHandle* t = EagerTensor_Handle(tensor);
   1468     tensorflow::int64 id = PyEagerTensor_ID(tensor);
   1469     tensorflow::TensorShape tensor_shape;
   1470     const tensorflow::Status status = t->handle->Shape(&tensor_shape);
   1471 
   1472     if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
   1473       return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
   1474                           tensorflow::TensorShape({}));
   1475     } else {
   1476       return PyTapeTensor(id, t->handle->dtype, tensor_shape);
   1477     }
   1478   }
   1479   tensorflow::int64 id = FastTensorId(tensor);
   1480   if (PyErr_Occurred()) {
   1481     return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
   1482                         tensorflow::TensorShape({}));
   1483   }
   1484   PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype");
   1485   PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum");
   1486   Py_DECREF(dtype_object);
   1487   tensorflow::DataType dtype =
   1488       static_cast<tensorflow::DataType>(MakeInt(dtype_enum));
   1489   Py_DECREF(dtype_enum);
   1490   if (PyErr_Occurred()) {
   1491     return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
   1492                         tensorflow::TensorShape({}));
   1493   }
   1494   static char _shape_tuple[] = "_shape_tuple";
   1495   PyObject* shape_tuple = PyObject_CallMethod(tensor, _shape_tuple, nullptr);
   1496   if (PyErr_Occurred()) {
   1497     return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
   1498                         tensorflow::TensorShape({}));
   1499   }
   1500 
   1501   if (ListContainsNone(shape_tuple)) {
   1502     return PyTapeTensor(id, dtype, tensor);
   1503   }
   1504 
   1505   auto l = MakeIntList(shape_tuple);
   1506   Py_DECREF(shape_tuple);
   1507   // Replace -1, which represents accidental Nones which can occur in graph mode
   1508   // and can cause errors in shape cosntruction with 0s.
   1509   for (auto& c : l) {
   1510     if (c < 0) {
   1511       c = 0;
   1512     }
   1513   }
   1514   tensorflow::TensorShape shape(l);
   1515   return PyTapeTensor(id, dtype, shape);
   1516 }
   1517 
   1518 std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
   1519   PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
   1520   if (seq == nullptr) {
   1521     return {};
   1522   }
   1523   int len = PySequence_Fast_GET_SIZE(seq);
   1524   std::vector<tensorflow::int64> list;
   1525   list.reserve(len);
   1526   for (int i = 0; i < len; ++i) {
   1527     PyObject* tensor = PySequence_Fast_GET_ITEM(seq, i);
   1528     list.push_back(FastTensorId(tensor));
   1529     if (PyErr_Occurred()) {
   1530       Py_DECREF(seq);
   1531       return list;
   1532     }
   1533   }
   1534   Py_DECREF(seq);
   1535   return list;
   1536 }
   1537 
   1538 void TFE_Py_TapeVariableAccessed(PyObject* variable) {
   1539   if (*ThreadTapeIsStopped()) {
   1540     return;
   1541   }
   1542   for (TFE_Py_Tape* tape : SafeTapeSet()) {
   1543     tape->tape->VariableAccessed(variable);
   1544   }
   1545 }
   1546 
   1547 void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) {
   1548   if (*ThreadTapeIsStopped()) {
   1549     return;
   1550   }
   1551   reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable);
   1552 }
   1553 
   1554 PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
   1555   return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
   1556 }
   1557 
   1558 namespace {
   1559 std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) {
   1560   PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
   1561   if (seq == nullptr) {
   1562     return {};
   1563   }
   1564   int len = PySequence_Fast_GET_SIZE(seq);
   1565   std::vector<tensorflow::DataType> list;
   1566   list.reserve(len);
   1567   for (int i = 0; i < len; ++i) {
   1568     PyObject* tensor = PySequence_Fast_GET_ITEM(seq, i);
   1569     list.push_back(FastTensorDtype(tensor));
   1570   }
   1571   Py_DECREF(seq);
   1572   return list;
   1573 }
   1574 
   1575 void TapeSetRecordOperation(
   1576     PyObject* op_type, PyObject* output_tensors,
   1577     const std::vector<tensorflow::int64>& input_ids,
   1578     const std::vector<tensorflow::DataType>& input_dtypes,
   1579     const std::function<PyBackwardFunction*()>& backward_function_getter,
   1580     const std::function<void(PyBackwardFunction*)>& backward_function_killer) {
   1581   std::vector<PyTapeTensor> output_info;
   1582   PyObject* seq = PySequence_Fast(output_tensors,
   1583                                   "expected a sequence of integer tensor ids");
   1584   int len = PySequence_Size(output_tensors);
   1585   if (PyErr_Occurred()) return;
   1586   output_info.reserve(len);
   1587   for (int i = 0; i < len; ++i) {
   1588     output_info.push_back(
   1589         TapeTensorFromTensor(PySequence_Fast_GET_ITEM(seq, i)));
   1590     if (PyErr_Occurred() != nullptr) {
   1591       Py_DECREF(seq);
   1592       return;
   1593     }
   1594   }
   1595   Py_DECREF(seq);
   1596   string op_type_str;
   1597   if (PyBytes_Check(op_type)) {
   1598     op_type_str = PyBytes_AsString(op_type);
   1599   } else if (PyUnicode_Check(op_type)) {
   1600 #if PY_MAJOR_VERSION >= 3
   1601     op_type_str = PyUnicode_AsUTF8(op_type);
   1602 #else
   1603     PyObject* py_str = PyUnicode_AsUTF8String(op_type);
   1604     if (py_str == nullptr) return;
   1605     op_type_str = PyBytes_AS_STRING(py_str);
   1606     Py_DECREF(py_str);
   1607 #endif
   1608   } else {
   1609     PyErr_SetString(PyExc_RuntimeError, "op_type should be a string.");
   1610     return;
   1611   }
   1612 
   1613   for (TFE_Py_Tape* tape : SafeTapeSet()) {
   1614     tape->tape->RecordOperation(op_type_str, output_info, input_ids,
   1615                                 input_dtypes, backward_function_getter,
   1616                                 backward_function_killer);
   1617   }
   1618 }
   1619 }  // namespace
   1620 
   1621 void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
   1622                                    PyObject* input_tensors,
   1623                                    PyObject* backward_function) {
   1624   if (GetTapeSet()->empty() || *ThreadTapeIsStopped()) {
   1625     return;
   1626   }
   1627   std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
   1628   if (PyErr_Occurred()) return;
   1629 
   1630   std::vector<tensorflow::DataType> input_dtypes =
   1631       MakeTensorDtypeList(input_tensors);
   1632   if (PyErr_Occurred()) return;
   1633 
   1634   TapeSetRecordOperation(
   1635       op_type, output_tensors, input_ids, input_dtypes,
   1636       [backward_function]() {
   1637         Py_INCREF(backward_function);
   1638         PyBackwardFunction* function =
   1639             new PyBackwardFunction([backward_function](PyObject* out_grads) {
   1640               return PyObject_CallObject(backward_function, out_grads);
   1641             });
   1642         return function;
   1643       },
   1644       [backward_function](PyBackwardFunction* py_backward_function) {
   1645         Py_DECREF(backward_function);
   1646         delete py_backward_function;
   1647       });
   1648 }
   1649 
   1650 void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
   1651   for (TFE_Py_Tape* tape : SafeTapeSet()) {
   1652     tape->tape->DeleteTrace(tensor_id);
   1653   }
   1654 }
   1655 
   1656 std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
   1657   PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
   1658   if (seq == nullptr) {
   1659     return {};
   1660   }
   1661   int len = PySequence_Fast_GET_SIZE(seq);
   1662   std::vector<PyObject*> list;
   1663   list.reserve(len);
   1664   for (int i = 0; i < len; ++i) {
   1665     list.push_back(PySequence_Fast_GET_ITEM(seq, i));
   1666   }
   1667   Py_DECREF(seq);
   1668   return list;
   1669 }
   1670 
   1671 PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
   1672                               PyObject* sources, PyObject* output_gradients,
   1673                               PyObject* unconnected_gradients,
   1674                               TF_Status* status) {
   1675   TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape);
   1676   if (!tape_obj->tape->IsPersistent()) {
   1677     auto* tape_set = GetTapeSet();
   1678     if (tape_set->find(tape_obj) != tape_set->end()) {
   1679       PyErr_SetString(PyExc_RuntimeError,
   1680                       "gradient() cannot be invoked within the "
   1681                       "GradientTape context (i.e., while operations are being "
   1682                       "recorded). Either move the call to gradient() to be "
   1683                       "outside the 'with tf.GradientTape' block, or "
   1684                       "use a persistent tape: "
   1685                       "'with tf.GradientTape(persistent=true)'");
   1686       return nullptr;
   1687     }
   1688   }
   1689 
   1690   std::vector<tensorflow::int64> target_vec = MakeTensorIDList(target);
   1691   if (PyErr_Occurred()) {
   1692     return nullptr;
   1693   }
   1694   std::vector<tensorflow::int64> sources_vec = MakeTensorIDList(sources);
   1695   if (PyErr_Occurred()) {
   1696     return nullptr;
   1697   }
   1698   tensorflow::gtl::FlatSet<tensorflow::int64> sources_set(sources_vec.begin(),
   1699                                                           sources_vec.end());
   1700 
   1701   tensorflow::Safe_PyObjectPtr seq =
   1702       tensorflow::make_safe(PySequence_Fast(target, "expected a sequence"));
   1703   int len = PySequence_Fast_GET_SIZE(seq.get());
   1704   tensorflow::gtl::FlatMap<tensorflow::int64, PyTapeTensor>
   1705       source_tensors_that_are_targets;
   1706   for (int i = 0; i < len; ++i) {
   1707     tensorflow::int64 target_id = target_vec[i];
   1708     if (sources_set.find(target_id) != sources_set.end()) {
   1709       auto tensor = PySequence_Fast_GET_ITEM(seq.get(), i);
   1710       source_tensors_that_are_targets.insert(
   1711           std::make_pair(target_id, TapeTensorFromTensor(tensor)));
   1712     }
   1713     if (PyErr_Occurred()) {
   1714       return nullptr;
   1715     }
   1716   }
   1717   if (PyErr_Occurred()) {
   1718     return nullptr;
   1719   }
   1720 
   1721   std::vector<PyObject*> outgrad_vec;
   1722   if (output_gradients != Py_None) {
   1723     outgrad_vec = MakeTensorList(output_gradients);
   1724     if (PyErr_Occurred()) {
   1725       return nullptr;
   1726     }
   1727     for (PyObject* tensor : outgrad_vec) {
   1728       // Calling the backward function will eat a reference to the tensors in
   1729       // outgrad_vec, so we need to increase their reference count.
   1730       Py_INCREF(tensor);
   1731     }
   1732   }
   1733   std::vector<PyObject*> result;
   1734   status->status = tape_obj->tape->ComputeGradient(
   1735       *py_vspace, target_vec, sources_vec, source_tensors_that_are_targets,
   1736       outgrad_vec, &result);
   1737   if (!status->status.ok()) {
   1738     if (PyErr_Occurred()) {
   1739       // Do not propagate the erroneous status as that would swallow the
   1740       // exception which caused the problem.
   1741       status->status = tensorflow::Status::OK();
   1742     }
   1743     return nullptr;
   1744   }
   1745 
   1746   bool unconnected_gradients_zero =
   1747       strcmp(TFE_GetPythonString(unconnected_gradients), "zero") == 0;
   1748   std::vector<PyObject*> sources_obj;
   1749   if (unconnected_gradients_zero) {
   1750     sources_obj = MakeTensorList(sources);
   1751   }
   1752 
   1753   if (!result.empty()) {
   1754     PyObject* py_result = PyList_New(result.size());
   1755     tensorflow::gtl::FlatSet<PyObject*> seen_results(result.size());
   1756     for (int i = 0; i < result.size(); ++i) {
   1757       if (result[i] == nullptr) {
   1758         if (unconnected_gradients_zero) {
   1759           // generate a zeros tensor in the shape of sources[i]
   1760           tensorflow::DataType dtype = FastTensorDtype(sources_obj[i]);
   1761           PyTapeTensor tensor =
   1762               PyTapeTensor(sources_vec[i], dtype, sources_obj[i]);
   1763           result[i] = py_vspace->Zeros(tensor);
   1764         } else {
   1765           Py_INCREF(Py_None);
   1766           result[i] = Py_None;
   1767         }
   1768       } else if (seen_results.find(result[i]) != seen_results.end()) {
   1769         Py_INCREF(result[i]);
   1770       }
   1771       seen_results.insert(result[i]);
   1772       PyList_SET_ITEM(py_result, i, reinterpret_cast<PyObject*>(result[i]));
   1773     }
   1774     return py_result;
   1775   }
   1776   return PyList_New(0);
   1777 }
   1778 
   1779 namespace {
   1780 static const int kFastPathExecuteInputStartIndex = 5;
   1781 
   1782 PyObject* GetPythonObjectFromString(const char* s) {
   1783 #if PY_MAJOR_VERSION >= 3
   1784   return PyUnicode_FromString(s);
   1785 #else
   1786   return PyBytes_FromString(s);
   1787 #endif
   1788 }
   1789 
   1790 PyObject* GetPythonObjectFromInt(int num) {
   1791 #if PY_MAJOR_VERSION >= 3
   1792   return PyLong_FromLong(num);
   1793 #else
   1794   return PyInt_FromLong(num);
   1795 #endif
   1796 }
   1797 
   1798 bool CheckResourceVariable(PyObject* item) {
   1799   return PyObject_TypeCheck(item, resource_variable_type);
   1800 }
   1801 
   1802 bool IsNumberType(PyObject* item) {
   1803 #if PY_MAJOR_VERSION >= 3
   1804   return PyFloat_Check(item) || PyLong_Check(item);
   1805 #else
   1806   return PyFloat_Check(item) || PyInt_Check(item) || PyLong_Check(item);
   1807 #endif
   1808 }
   1809 
   1810 bool CheckOneInput(PyObject* item) {
   1811   if (EagerTensor_CheckExact(item) || CheckResourceVariable(item) ||
   1812       PyArray_Check(item) || IsNumberType(item)) {
   1813     return true;
   1814   }
   1815 
   1816   // Sequences are not properly handled. Sequences with purely python numeric
   1817   // types work, but sequences with mixes of EagerTensors and python numeric
   1818   // types don't work.
   1819   // TODO(nareshmodi): fix
   1820   return false;
   1821 }
   1822 
   1823 bool CheckInputsOk(PyObject* seq, int start_index,
   1824                    const tensorflow::OpDef& op_def) {
   1825   for (int i = 0; i < op_def.input_arg_size(); i++) {
   1826     PyObject* item = PyTuple_GET_ITEM(seq, i + start_index);
   1827     if (!op_def.input_arg(i).number_attr().empty() ||
   1828         !op_def.input_arg(i).type_list_attr().empty()) {
   1829       // This item should be a seq input.
   1830       if (!PySequence_Check(item)) {
   1831         VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
   1832                 << "\", Input \"" << op_def.input_arg(i).name()
   1833                 << "\" since we expected a sequence, but got "
   1834                 << item->ob_type->tp_name;
   1835         return false;
   1836       }
   1837       for (Py_ssize_t j = 0; j < PySequence_Fast_GET_SIZE(item); j++) {
   1838         PyObject* inner_item = PySequence_Fast_GET_ITEM(item, j);
   1839         if (!CheckOneInput(inner_item)) {
   1840           VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
   1841                   << "\", Input \"" << op_def.input_arg(i).name()
   1842                   << "\", Index " << j
   1843                   << " since we expected an EagerTensor/ResourceVariable, "
   1844                      "but got "
   1845                   << inner_item->ob_type->tp_name;
   1846           return false;
   1847         }
   1848       }
   1849     } else if (!CheckOneInput(item)) {
   1850       VLOG(1)
   1851           << "Falling back to slow path for Op \"" << op_def.name()
   1852           << "\", Input \"" << op_def.input_arg(i).name()
   1853           << "\" since we expected an EagerTensor/ResourceVariable, but got "
   1854           << item->ob_type->tp_name;
   1855       return false;
   1856     }
   1857   }
   1858 
   1859   return true;
   1860 }
   1861 
   1862 PyObject* MaybeGetDType(PyObject* item) {
   1863   if (EagerTensor_CheckExact(item)) {
   1864     tensorflow::Safe_PyObjectPtr py_dtype(
   1865         PyObject_GetAttrString(item, "dtype"));
   1866     return PyObject_GetAttrString(py_dtype.get(), "_type_enum");
   1867   }
   1868 
   1869   if (CheckResourceVariable(item)) {
   1870     tensorflow::Safe_PyObjectPtr py_dtype(
   1871         PyObject_GetAttrString(item, "_dtype"));
   1872     return PyObject_GetAttrString(py_dtype.get(), "_type_enum");
   1873   }
   1874 
   1875   return nullptr;
   1876 }
   1877 
   1878 PyObject* MaybeGetDTypeForAttr(const string& attr,
   1879                                FastPathOpExecInfo* op_exec_info) {
   1880   auto cached_it = op_exec_info->cached_dtypes.find(attr);
   1881   if (cached_it != op_exec_info->cached_dtypes.end()) {
   1882     return GetPythonObjectFromInt(cached_it->second);
   1883   }
   1884 
   1885   auto it = op_exec_info->attr_to_inputs_map->find(attr);
   1886   if (it == op_exec_info->attr_to_inputs_map->end()) {
   1887     // No other inputs - this should never happen.
   1888     Py_RETURN_NONE;
   1889   }
   1890 
   1891   for (const auto& input_info : it->second) {
   1892     PyObject* item = PyTuple_GET_ITEM(
   1893         op_exec_info->args, kFastPathExecuteInputStartIndex + input_info.i);
   1894     if (input_info.is_list) {
   1895       for (int i = 0; i < PySequence_Fast_GET_SIZE(item); i++) {
   1896         auto* dtype = MaybeGetDType(PySequence_Fast_GET_ITEM(item, i));
   1897         if (dtype != nullptr) return dtype;
   1898       }
   1899     } else {
   1900       auto* dtype = MaybeGetDType(item);
   1901       if (dtype != nullptr) return dtype;
   1902     }
   1903   }
   1904 
   1905   Py_RETURN_NONE;
   1906 }
   1907 
   1908 // TODO(agarwal): use an automatic mechanism for handling None arguments to
   1909 // gradient functions.
   1910 
   1911 // Returns a pair where the first value of the pair indicates whether or not all
   1912 // outputs are unused. If the first value is false, the second value is a
   1913 // set that identifies which of the output indices are unused.
   1914 bool OpGradientDoesntRequireOutputIndices(
   1915     const string& op_name,
   1916     std::pair<bool, tensorflow::gtl::FlatSet<int>>** output) {
   1917   static tensorflow::gtl::FlatMap<
   1918       string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>* m =
   1919       new tensorflow::gtl::FlatMap<
   1920           string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>({
   1921           // Ops that don't require any outputs.
   1922           {"Identity", {true, {}}},
   1923           {"MatMul", {true, {}}},
   1924           {"Conv2DBackpropInput", {true, {}}},
   1925           {"Conv2DBackpropFilter", {true, {}}},
   1926           {"Conv3D", {true, {}}},
   1927           {"Conv3DBackpropInputV2", {true, {}}},
   1928           {"AvgPool3D", {true, {}}},
   1929           {"AvgPool3DGrad", {true, {}}},
   1930           {"MaxPool3D", {false, {}}},
   1931           {"MaxPool3DGrad", {true, {}}},
   1932           {"MaxPool3DGradGrad", {true, {}}},
   1933           {"BiasAdd", {true, {}}},
   1934           {"BiasAddV1", {true, {}}},
   1935           {"BiasAddGrad", {true, {}}},
   1936           {"Softplus", {true, {}}},
   1937           {"SoftplusGrad", {true, {}}},
   1938           {"Softsign", {true, {}}},
   1939           {"ReluGrad", {true, {}}},
   1940           {"LeakyRelu", {true, {}}},
   1941           {"LeakyReluGrad", {true, {}}},
   1942           {"Conv2D", {true, {}}},
   1943           {"DepthwiseConv2dNative", {true, {}}},
   1944           {"Dilation2D", {true, {}}},
   1945           {"AvgPool", {true, {}}},
   1946           {"AvgPoolGrad", {true, {}}},
   1947           {"BatchNormWithGlobalNormalization", {true, {}}},
   1948           {"L2Loss", {true, {}}},
   1949           {"Sum", {true, {}}},
   1950           {"Prod", {true, {}}},
   1951           {"SegmentSum", {true, {}}},
   1952           {"SegmentMean", {true, {}}},
   1953           {"SparseSegmentSum", {true, {}}},
   1954           {"SparseSegmentMean", {true, {}}},
   1955           {"SparseSegmentSqrtN", {true, {}}},
   1956           {"UnsortedSegmentSum", {true, {}}},
   1957           {"UnsortedSegmentMax", {true, {}}},
   1958           {"Abs", {true, {}}},
   1959           {"Neg", {true, {}}},
   1960           {"ReciprocalGrad", {true, {}}},
   1961           {"Square", {true, {}}},
   1962           {"Expm1", {true, {}}},
   1963           {"Log", {true, {}}},
   1964           {"Log1p", {true, {}}},
   1965           {"TanhGrad", {true, {}}},
   1966           {"SigmoidGrad", {true, {}}},
   1967           {"Sign", {true, {}}},
   1968           {"Sin", {true, {}}},
   1969           {"Cos", {true, {}}},
   1970           {"Tan", {true, {}}},
   1971           {"Add", {true, {}}},
   1972           {"Sub", {true, {}}},
   1973           {"Mul", {true, {}}},
   1974           {"Div", {true, {}}},
   1975           {"RealDiv", {true, {}}},
   1976           {"Maximum", {true, {}}},
   1977           {"Minimum", {true, {}}},
   1978           {"SquaredDifference", {true, {}}},
   1979           {"Select", {true, {}}},
   1980           {"SparseMatMul", {true, {}}},
   1981           {"BatchMatMul", {true, {}}},
   1982           {"Complex", {true, {}}},
   1983           {"Real", {true, {}}},
   1984           {"Imag", {true, {}}},
   1985           {"Angle", {true, {}}},
   1986           {"Conj", {true, {}}},
   1987           {"Cast", {true, {}}},
   1988           {"Cross", {true, {}}},
   1989           {"Cumsum", {true, {}}},
   1990           {"Cumprod", {true, {}}},
   1991           {"ReadVariableOp", {true, {}}},
   1992           {"VarHandleOp", {true, {}}},
   1993           {"Shape", {true, {}}},
   1994           {"StridedSlice", {true, {}}},
   1995           {"Fill", {true, {}}},
   1996 
   1997           // Ops that don't require a subset of outputs.
   1998           {"FusedBatchNorm", {false, {0, 1, 2}}},
   1999       });
   2000 
   2001   auto it = m->find(op_name);
   2002 
   2003   if (it == m->end()) return false;
   2004 
   2005   *output = &it->second;
   2006   return true;
   2007 }
   2008 
   2009 // Returns a pair where the first value of the pair indicates whether or not all
   2010 // inputs are unused. If the first value is false, the second value is a
   2011 // set that identifies which of the input indices are unused.
   2012 bool OpGradientDoesntRequireInputIndices(
   2013     const string& op_name,
   2014     std::pair<bool, tensorflow::gtl::FlatSet<int>>** output) {
   2015   static tensorflow::gtl::FlatMap<
   2016       string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>* m =
   2017       new tensorflow::gtl::FlatMap<
   2018           string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>({
   2019           // Ops that don't require any inputs.
   2020           {"Identity", {true, {}}},
   2021           {"Softmax", {true, {}}},
   2022           {"LogSoftmax", {true, {}}},
   2023           {"BiasAdd", {true, {}}},
   2024           {"Relu", {true, {}}},
   2025           {"Relu6", {true, {}}},
   2026           {"Elu", {true, {}}},
   2027           {"Selu", {true, {}}},
   2028           {"SparseSoftmaxCrossEntropyWithLogits", {true, {}}},
   2029           {"Neg", {true, {}}},
   2030           {"Inv", {true, {}}},
   2031           {"Reciprocal", {true, {}}},
   2032           {"Sqrt", {true, {}}},
   2033           {"Exp", {true, {}}},
   2034           {"Tanh", {true, {}}},
   2035           {"Sigmoid", {true, {}}},
   2036           {"Real", {true, {}}},
   2037           {"Imag", {true, {}}},
   2038           {"Conj", {true, {}}},
   2039           {"ReadVariableOp", {true, {}}},
   2040           {"VarHandleOp", {true, {}}},
   2041           {"Shape", {true, {}}},
   2042           {"Fill", {true, {}}},
   2043 
   2044           // Ops that don't require a subset of inputs.
   2045           {"FusedBatchNorm", {false, {2}}},
   2046       });
   2047 
   2048   auto it = m->find(op_name);
   2049 
   2050   if (it == m->end()) return false;
   2051 
   2052   *output = &it->second;
   2053   return true;
   2054 }
   2055 
   2056 PyObject* CopySequenceSettingIndicesToNull(
   2057     PyObject* seq, const tensorflow::gtl::FlatSet<int>& indices) {
   2058   tensorflow::Safe_PyObjectPtr fast_seq(
   2059       PySequence_Fast(seq, "unable to allocate"));
   2060   PyObject* result = PyTuple_New(PySequence_Fast_GET_SIZE(fast_seq.get()));
   2061   for (int i = 0; i < PySequence_Fast_GET_SIZE(fast_seq.get()); i++) {
   2062     PyObject* item;
   2063     if (indices.find(i) != indices.end()) {
   2064       item = Py_None;
   2065     } else {
   2066       item = PySequence_Fast_GET_ITEM(fast_seq.get(), i);
   2067     }
   2068     Py_INCREF(item);
   2069     PyTuple_SET_ITEM(result, i, item);
   2070   }
   2071   return result;
   2072 }
   2073 
   2074 PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
   2075                          PyObject* results, PyObject* name) {
   2076   std::vector<tensorflow::int64> input_ids = MakeTensorIDList(inputs);
   2077   if (PyErr_Occurred()) return nullptr;
   2078   std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs);
   2079   if (PyErr_Occurred()) return nullptr;
   2080 
   2081   bool should_record = false;
   2082   for (TFE_Py_Tape* tape : SafeTapeSet()) {
   2083     if (tape->tape->ShouldRecord(input_ids, input_dtypes)) {
   2084       should_record = true;
   2085       break;
   2086     }
   2087   }
   2088   if (!should_record) Py_RETURN_NONE;
   2089 
   2090   string c_op_name = TFE_GetPythonString(op_name);
   2091 
   2092   PyObject* op_outputs;
   2093   bool op_outputs_tuple_created = false;
   2094   std::pair<bool, tensorflow::gtl::FlatSet<int>>* outputs_not_required;
   2095 
   2096   if (OpGradientDoesntRequireOutputIndices(c_op_name, &outputs_not_required)) {
   2097     if (outputs_not_required->first) {
   2098       op_outputs = Py_None;
   2099     } else {
   2100       op_outputs_tuple_created = true;
   2101       op_outputs = CopySequenceSettingIndicesToNull(
   2102           results, outputs_not_required->second);
   2103     }
   2104   } else {
   2105     op_outputs = results;
   2106   }
   2107 
   2108   PyObject* op_inputs;
   2109   bool op_inputs_tuple_created = false;
   2110   std::pair<bool, tensorflow::gtl::FlatSet<int>>* inputs_not_required;
   2111 
   2112   if (OpGradientDoesntRequireInputIndices(c_op_name, &inputs_not_required)) {
   2113     if (inputs_not_required->first) {
   2114       op_inputs = Py_None;
   2115     } else {
   2116       op_inputs_tuple_created = true;
   2117       op_inputs =
   2118           CopySequenceSettingIndicesToNull(inputs, inputs_not_required->second);
   2119     }
   2120   } else {
   2121     op_inputs = inputs;
   2122   }
   2123 
   2124   PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs));
   2125 
   2126   TapeSetRecordOperation(
   2127       op_name, results, input_ids, input_dtypes,
   2128       [op_name, attrs, num_inputs, op_inputs, op_outputs]() {
   2129         Py_INCREF(op_name);
   2130         Py_INCREF(attrs);
   2131         Py_INCREF(num_inputs);
   2132         Py_INCREF(op_inputs);
   2133         Py_INCREF(op_outputs);
   2134         PyBackwardFunction* function =
   2135             new PyBackwardFunction([op_name, attrs, num_inputs, op_inputs,
   2136                                     op_outputs](PyObject* output_grads) {
   2137               if (PyErr_Occurred()) {
   2138                 return static_cast<PyObject*>(nullptr);
   2139               }
   2140               tensorflow::Safe_PyObjectPtr callback_args(
   2141                   Py_BuildValue("OOOOOO", op_name, attrs, num_inputs, op_inputs,
   2142                                 op_outputs, output_grads));
   2143 
   2144               tensorflow::Safe_PyObjectPtr result(
   2145                   PyObject_CallObject(gradient_function, callback_args.get()));
   2146 
   2147               if (PyErr_Occurred()) return static_cast<PyObject*>(nullptr);
   2148 
   2149               return tensorflow::swig::Flatten(result.get());
   2150             });
   2151         return function;
   2152       },
   2153       [op_name, attrs, num_inputs, op_inputs,
   2154        op_outputs](PyBackwardFunction* backward_function) {
   2155         Py_DECREF(op_name);
   2156         Py_DECREF(attrs);
   2157         Py_DECREF(num_inputs);
   2158         Py_DECREF(op_inputs);
   2159         Py_DECREF(op_outputs);
   2160 
   2161         delete backward_function;
   2162       });
   2163 
   2164   Py_DECREF(num_inputs);
   2165   if (op_outputs_tuple_created) Py_DECREF(op_outputs);
   2166   if (op_inputs_tuple_created) Py_DECREF(op_inputs);
   2167 
   2168   Py_RETURN_NONE;
   2169 }
   2170 
   2171 void MaybeNotifyVariableAccessed(PyObject* input) {
   2172   DCHECK(CheckResourceVariable(input));
   2173   DCHECK(PyObject_HasAttrString(input, "_trainable"));
   2174 
   2175   tensorflow::Safe_PyObjectPtr trainable(
   2176       PyObject_GetAttrString(input, "_trainable"));
   2177   if (trainable.get() == Py_False) return;
   2178   TFE_Py_TapeVariableAccessed(input);
   2179 }
   2180 
   2181 bool CastTensor(const FastPathOpExecInfo& op_exec_info,
   2182                 const TF_DataType& desired_dtype,
   2183                 tensorflow::Safe_TFE_TensorHandlePtr* handle,
   2184                 TF_Status* status) {
   2185   TF_DataType input_dtype = TFE_TensorHandleDataType(handle->get());
   2186   TF_DataType output_dtype = input_dtype;
   2187 
   2188   if (desired_dtype >= 0 && desired_dtype != input_dtype) {
   2189     *handle = tensorflow::make_safe(
   2190         tensorflow::EagerCast(op_exec_info.ctx, handle->get(), input_dtype,
   2191                               static_cast<TF_DataType>(desired_dtype), status));
   2192     if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
   2193       return false;
   2194     }
   2195     output_dtype = desired_dtype;
   2196   }
   2197 
   2198   if (output_dtype != TF_INT32) {
   2199     // Note that this is a shallow copy and will share the underlying buffer
   2200     // if copying to the same device.
   2201     *handle = tensorflow::make_safe(TFE_TensorHandleCopyToDevice(
   2202         handle->get(), op_exec_info.ctx, op_exec_info.device_name, status));
   2203     if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
   2204       return false;
   2205     }
   2206   }
   2207   return true;
   2208 }
   2209 
   2210 bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
   2211                     PyObject* input, tensorflow::Safe_PyObjectPtr* output,
   2212                     TF_Status* status) {
   2213   MaybeNotifyVariableAccessed(input);
   2214 
   2215   TFE_Op* op = TFE_NewOp(parent_op_exec_info.ctx, "ReadVariableOp", status);
   2216   auto cleaner = tensorflow::gtl::MakeCleanup([op] { TFE_DeleteOp(op); });
   2217   if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
   2218 
   2219   // Set dtype
   2220   DCHECK(PyObject_HasAttrString(input, "_dtype"));
   2221   tensorflow::Safe_PyObjectPtr dtype(PyObject_GetAttrString(input, "_dtype"));
   2222   int value;
   2223   if (!ParseTypeValue("_dtype", dtype.get(), status, &value)) {
   2224     return false;
   2225   }
   2226   TFE_OpSetAttrType(op, "dtype", static_cast<TF_DataType>(value));
   2227 
   2228   TFE_OpSetDevice(op, parent_op_exec_info.device_name, status);
   2229   if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
   2230 
   2231   // Get handle
   2232   tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(input, "_handle"));
   2233   if (!EagerTensor_CheckExact(handle.get())) return false;
   2234   TFE_OpAddInput(op, EagerTensor_Handle(handle.get()), status);
   2235   if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
   2236 
   2237   int num_retvals = 1;
   2238   TFE_TensorHandle* output_handle;
   2239   TFE_Execute(op, &output_handle, &num_retvals, status);
   2240   if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
   2241 
   2242   // Always create the py object (and correctly DECREF it) from the returned
   2243   // value, else the data will leak.
   2244   output->reset(EagerTensorFromHandle(output_handle));
   2245 
   2246   // TODO(nareshmodi): Should we run post exec callbacks here?
   2247   if (parent_op_exec_info.run_gradient_callback) {
   2248     tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(1));
   2249     PyTuple_SET_ITEM(inputs.get(), 0, handle.release());
   2250 
   2251     tensorflow::Safe_PyObjectPtr outputs(PyTuple_New(1));
   2252     Py_INCREF(output->get());  // stay alive after since tuple steals.
   2253     PyTuple_SET_ITEM(outputs.get(), 0, output->get());
   2254 
   2255     tensorflow::Safe_PyObjectPtr op_string(
   2256         GetPythonObjectFromString("ReadVariableOp"));
   2257     if (!RecordGradient(op_string.get(), inputs.get(), Py_None, outputs.get(),
   2258                         Py_None)) {
   2259       return false;
   2260     }
   2261   }
   2262 
   2263   return true;
   2264 }
   2265 
   2266 // Supports 3 cases at the moment:
   2267 //  i) input is an EagerTensor.
   2268 //  ii) input is a ResourceVariable - in this case, the is_variable param is
   2269 //  set to true.
   2270 //  iii) input is an arbitrary python list/tuple (note, this handling doesn't
   2271 //  support packing).
   2272 //
   2273 //  NOTE: dtype_hint_getter must *always* return a PyObject that can be
   2274 //  decref'd. So if no hint is found, Py_RETURN_NONE (which correctly
   2275 //  increfs Py_None).
   2276 //
   2277 //  NOTE: This function sets a python error directly, and returns false.
   2278 //  TF_Status is only passed since we don't want to have to reallocate it.
   2279 bool ConvertToTensor(
   2280     const FastPathOpExecInfo& op_exec_info, PyObject* input,
   2281     tensorflow::Safe_PyObjectPtr* output_handle,
   2282     // This gets a hint for this particular input.
   2283     const std::function<PyObject*()>& dtype_hint_getter,
   2284     // This sets the dtype after conversion is complete.
   2285     const std::function<void(const TF_DataType& dtype)>& dtype_setter,
   2286     TF_Status* status) {
   2287   if (EagerTensor_CheckExact(input)) {
   2288     Py_INCREF(input);
   2289     output_handle->reset(input);
   2290     return true;
   2291   } else if (CheckResourceVariable(input)) {
   2292     return ReadVariableOp(op_exec_info, input, output_handle, status);
   2293   }
   2294 
   2295   // The hint comes from a supposedly similarly typed tensor.
   2296   tensorflow::Safe_PyObjectPtr dtype_hint(dtype_hint_getter());
   2297   if (PyErr_Occurred()) {
   2298     return false;
   2299   }
   2300 
   2301   tensorflow::Safe_TFE_TensorHandlePtr handle =
   2302       tensorflow::make_safe(static_cast<TFE_TensorHandle*>(
   2303           tensorflow::ConvertToEagerTensor(input, dtype_hint.get())));
   2304   if (handle == nullptr) {
   2305     return MaybeRaiseExceptionFromTFStatus(status, nullptr);
   2306   }
   2307 
   2308   int desired_dtype = -1;
   2309   if (dtype_hint.get() != Py_None) {
   2310     if (!ParseTypeValue("", dtype_hint.get(), status, &desired_dtype)) {
   2311       PyErr_SetString(PyExc_TypeError,
   2312                       tensorflow::strings::StrCat(
   2313                           "Expecting a DataType value for dtype. Got ",
   2314                           Py_TYPE(dtype_hint.get())->tp_name)
   2315                           .c_str());
   2316       return false;
   2317     }
   2318   }
   2319 
   2320   // Maybe cast to the desired type. This is intended to match python
   2321   // convert_to_tensor behavior.
   2322   TF_DataType output_dtype = TFE_TensorHandleDataType(handle.get());
   2323   if (desired_dtype >= 0 && desired_dtype != output_dtype) {
   2324     if (tensorflow::IsCompatible(desired_dtype, output_dtype)) {
   2325       if (!CastTensor(op_exec_info, static_cast<TF_DataType>(desired_dtype),
   2326                       &handle, status)) {
   2327         return false;
   2328       }
   2329       output_dtype = TFE_TensorHandleDataType(handle.get());
   2330     } else {
   2331       tensorflow::Safe_PyObjectPtr input_str(PyObject_Str(input));
   2332       PyErr_SetString(
   2333           PyExc_TypeError,
   2334           tensorflow::strings::StrCat(
   2335               "Cannot convert provided value to EagerTensor. Provided value: ",
   2336               TFE_GetPythonString(input_str.get()), " Requested dtype: ",
   2337               tensorflow::DataTypeString(
   2338                   static_cast<tensorflow::DataType>(desired_dtype)))
   2339               .c_str());
   2340       return false;
   2341     }
   2342   }
   2343 
   2344   output_handle->reset(EagerTensorFromHandle(handle.release()));
   2345   dtype_setter(output_dtype);
   2346 
   2347   return true;
   2348 }
   2349 
   2350 // Adds input and type attr to the op, and to the list of flattened
   2351 // inputs/attrs.
   2352 bool AddInputToOp(FastPathOpExecInfo* op_exec_info, PyObject* input,
   2353                   const bool add_type_attr,
   2354                   const tensorflow::OpDef::ArgDef& input_arg,
   2355                   std::vector<tensorflow::Safe_PyObjectPtr>* flattened_attrs,
   2356                   std::vector<tensorflow::Safe_PyObjectPtr>* flattened_inputs,
   2357                   TFE_Op* op, TF_Status* status) {
   2358   // py_eager_tensor's ownership is transferred to flattened_inputs if it is
   2359   // required, else the object is destroyed and DECREF'd when the object goes
   2360   // out of scope in this function.
   2361   tensorflow::Safe_PyObjectPtr py_eager_tensor = nullptr;
   2362 
   2363   if (!ConvertToTensor(
   2364           *op_exec_info, input, &py_eager_tensor,
   2365           [&]() {
   2366             if (input_arg.type() != tensorflow::DataType::DT_INVALID) {
   2367               return GetPythonObjectFromInt(input_arg.type());
   2368             }
   2369             return MaybeGetDTypeForAttr(input_arg.type_attr(), op_exec_info);
   2370           },
   2371           [&](const TF_DataType dtype) {
   2372             op_exec_info->cached_dtypes[input_arg.type_attr()] =
   2373                 static_cast<tensorflow::DataType>(dtype);
   2374           },
   2375           status)) {
   2376     return false;
   2377   }
   2378 
   2379   TFE_TensorHandle* input_handle = EagerTensor_Handle(py_eager_tensor.get());
   2380 
   2381   if (add_type_attr && !input_arg.type_attr().empty()) {
   2382     auto dtype = TFE_TensorHandleDataType(input_handle);
   2383     TFE_OpSetAttrType(op, input_arg.type_attr().data(), dtype);
   2384     if (flattened_attrs != nullptr) {
   2385       flattened_attrs->emplace_back(
   2386           GetPythonObjectFromString(input_arg.type_attr().data()));
   2387       flattened_attrs->emplace_back(PyLong_FromLong(dtype));
   2388     }
   2389   }
   2390 
   2391   if (flattened_inputs != nullptr) {
   2392     flattened_inputs->emplace_back(std::move(py_eager_tensor));
   2393   }
   2394 
   2395   TFE_OpAddInput(op, input_handle, status);
   2396   if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
   2397     return false;
   2398   }
   2399 
   2400   return true;
   2401 }
   2402 
   2403 const tensorflow::OpDef* GetOpDef(PyObject* py_op_name) {
   2404   const char* op_name = TFE_GetPythonString(py_op_name);
   2405   if (op_name == nullptr) {
   2406     PyErr_SetString(PyExc_TypeError,
   2407                     Printf("expected a string for op_name, got %s instead",
   2408                            py_op_name->ob_type->tp_name)
   2409                         .c_str());
   2410     return nullptr;
   2411   }
   2412 
   2413   const tensorflow::OpRegistrationData* op_reg_data = nullptr;
   2414   const tensorflow::Status lookup_status =
   2415       tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data);
   2416   if (MaybeRaiseExceptionFromStatus(lookup_status, nullptr)) {
   2417     return nullptr;
   2418   }
   2419   return &op_reg_data->op_def;
   2420 }
   2421 
   2422 const char* GetDeviceName(PyObject* py_device_name) {
   2423   if (py_device_name != Py_None) {
   2424     return TFE_GetPythonString(py_device_name);
   2425   }
   2426   return nullptr;
   2427 }
   2428 
   2429 bool RaiseIfNotPySequence(PyObject* seq, const string& attr_name) {
   2430   if (!PySequence_Check(seq)) {
   2431     PyErr_SetString(PyExc_TypeError,
   2432                     Printf("expected a sequence for attr %s, got %s instead",
   2433                            attr_name.data(), seq->ob_type->tp_name)
   2434                         .data());
   2435 
   2436     return false;
   2437   }
   2438   return true;
   2439 }
   2440 
   2441 bool RunCallbacks(
   2442     const FastPathOpExecInfo& op_exec_info, PyObject* args,
   2443     const std::vector<tensorflow::Safe_PyObjectPtr>* const flattened_inputs,
   2444     const std::vector<tensorflow::Safe_PyObjectPtr>* const flattened_attrs,
   2445     PyObject* flattened_result) {
   2446   if (!op_exec_info.run_callbacks) return true;
   2447 
   2448   tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(flattened_inputs->size()));
   2449   for (int i = 0; i < flattened_inputs->size(); i++) {
   2450     PyObject* input = (*flattened_inputs)[i].get();
   2451     Py_INCREF(input);
   2452     PyTuple_SET_ITEM(inputs.get(), i, input);
   2453   }
   2454 
   2455   int num_non_inferred_attrs = PyTuple_GET_SIZE(args) -
   2456                                op_exec_info.op_def->input_arg_size() -
   2457                                kFastPathExecuteInputStartIndex;
   2458   int num_attrs = flattened_attrs->size() + num_non_inferred_attrs;
   2459   tensorflow::Safe_PyObjectPtr attrs(PyTuple_New(num_attrs));
   2460 
   2461   for (int i = 0; i < num_non_inferred_attrs; i++) {
   2462     auto* attr =
   2463         PyTuple_GET_ITEM(args, kFastPathExecuteInputStartIndex +
   2464                                    op_exec_info.op_def->input_arg_size() + i);
   2465     Py_INCREF(attr);
   2466     PyTuple_SET_ITEM(attrs.get(), i, attr);
   2467   }
   2468   for (int i = num_non_inferred_attrs; i < num_attrs; i++) {
   2469     PyObject* attr_or_name =
   2470         flattened_attrs->at(i - num_non_inferred_attrs).get();
   2471     Py_INCREF(attr_or_name);
   2472     PyTuple_SET_ITEM(attrs.get(), i, attr_or_name);
   2473   }
   2474 
   2475   if (op_exec_info.run_gradient_callback) {
   2476     if (!RecordGradient(op_exec_info.op_name, inputs.get(), attrs.get(),
   2477                         flattened_result, op_exec_info.name)) {
   2478       return false;
   2479     }
   2480   }
   2481 
   2482   if (op_exec_info.run_post_exec_callbacks) {
   2483     tensorflow::Safe_PyObjectPtr callback_args(
   2484         Py_BuildValue("OOOOO", op_exec_info.op_name, inputs.get(), attrs.get(),
   2485                       flattened_result, op_exec_info.name));
   2486     for (Py_ssize_t i = 0; i < PyList_Size(op_exec_info.callbacks); i++) {
   2487       PyObject* callback_fn = PyList_GET_ITEM(op_exec_info.callbacks, i);
   2488       if (!PyCallable_Check(callback_fn)) {
   2489         PyErr_SetString(
   2490             PyExc_TypeError,
   2491             Printf("expected a function for "
   2492                    "post execution callback in index %ld, got %s instead",
   2493                    i, callback_fn->ob_type->tp_name)
   2494                 .c_str());
   2495         return false;
   2496       }
   2497       PyObject* callback_result =
   2498           PyObject_CallObject(callback_fn, callback_args.get());
   2499       if (!callback_result) {
   2500         return false;
   2501       }
   2502       Py_DECREF(callback_result);
   2503     }
   2504   }
   2505 
   2506   return true;
   2507 }
   2508 
   2509 }  // namespace
   2510 
   2511 PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
   2512   Py_ssize_t args_size = PyTuple_GET_SIZE(args);
   2513   if (args_size < kFastPathExecuteInputStartIndex) {
   2514     PyErr_SetString(
   2515         PyExc_ValueError,
   2516         Printf("There must be at least %d items in the input tuple.",
   2517                kFastPathExecuteInputStartIndex)
   2518             .c_str());
   2519     return nullptr;
   2520   }
   2521 
   2522   FastPathOpExecInfo op_exec_info;
   2523 
   2524   op_exec_info.ctx = reinterpret_cast<TFE_Context*>(
   2525       PyCapsule_GetPointer(PyTuple_GET_ITEM(args, 0), nullptr));
   2526   op_exec_info.args = args;
   2527 
   2528   if (op_exec_info.ctx == nullptr) {
   2529     // The context hasn't been initialized. It will be in the slow path.
   2530     RaiseFallbackException(
   2531         "This function does not handle the case of the path where "
   2532         "all inputs are not already EagerTensors.");
   2533     return nullptr;
   2534   }
   2535 
   2536   op_exec_info.device_name = GetDeviceName(PyTuple_GET_ITEM(args, 1));
   2537   op_exec_info.op_name = PyTuple_GET_ITEM(args, 2);
   2538   op_exec_info.op_def = GetOpDef(op_exec_info.op_name);
   2539   if (op_exec_info.op_def == nullptr) return nullptr;
   2540   op_exec_info.name = PyTuple_GET_ITEM(args, 3);
   2541   op_exec_info.callbacks = PyTuple_GET_ITEM(args, 4);
   2542 
   2543   const tensorflow::OpDef* op_def = op_exec_info.op_def;
   2544 
   2545   // TODO(nareshmodi): Add a benchmark for the fast-path with gradient callbacks
   2546   // (similar to benchmark_tf_gradient_function_*). Also consider using an
   2547   // InlinedVector for flattened_attrs and flattened_inputs if the benchmarks
   2548   // point out problems with heap allocs.
   2549   op_exec_info.run_gradient_callback =
   2550       !*ThreadTapeIsStopped() && !GetTapeSet()->empty();
   2551   op_exec_info.run_post_exec_callbacks =
   2552       op_exec_info.callbacks != Py_None &&
   2553       PyList_Size(op_exec_info.callbacks) > 0;
   2554   op_exec_info.run_callbacks = op_exec_info.run_gradient_callback ||
   2555                                op_exec_info.run_post_exec_callbacks;
   2556 
   2557   if (args_size < kFastPathExecuteInputStartIndex + op_def->input_arg_size()) {
   2558     PyErr_SetString(
   2559         PyExc_ValueError,
   2560         Printf("Tuple size smaller than intended. Expected to be at least %d, "
   2561                "was %ld",
   2562                kFastPathExecuteInputStartIndex + op_def->input_arg_size(),
   2563                args_size)
   2564             .c_str());
   2565     return nullptr;
   2566   }
   2567 
   2568   if (!CheckInputsOk(args, kFastPathExecuteInputStartIndex, *op_def)) {
   2569     RaiseFallbackException(
   2570         "This function does not handle the case of the path where "
   2571         "all inputs are not already EagerTensors.");
   2572     return nullptr;
   2573   }
   2574 
   2575   op_exec_info.attr_to_inputs_map = GetAttrToInputsMap(*op_def);
   2576 
   2577   TF_Status* status = TF_NewStatus();
   2578   TFE_Op* op = TFE_NewOp(op_exec_info.ctx, op_def->name().c_str(), status);
   2579   auto cleaner = tensorflow::gtl::MakeCleanup([status, op] {
   2580     TF_DeleteStatus(status);
   2581     TFE_DeleteOp(op);
   2582   });
   2583   if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
   2584     return nullptr;
   2585   }
   2586 
   2587   // Mapping of attr name to size - used to calculate the number of values
   2588   // to be expected by the TFE_Execute run.
   2589   tensorflow::gtl::FlatMap<string, tensorflow::int64> attr_list_sizes;
   2590 
   2591   // Set non-inferred attrs, including setting defaults if the attr is passed in
   2592   // as None.
   2593   for (int i = kFastPathExecuteInputStartIndex + op_def->input_arg_size();
   2594        i < args_size; i += 2) {
   2595     PyObject* py_attr_name = PyTuple_GET_ITEM(args, i);
   2596     const tensorflow::StringPiece attr_name(TFE_GetPythonString(py_attr_name));
   2597     PyObject* py_attr_value = PyTuple_GET_ITEM(args, i + 1);
   2598 
   2599     // Not creating an index since most of the time there are not more than a
   2600     // few attrs.
   2601     // TODO(nareshmodi): Maybe include the index as part of the
   2602     // OpRegistrationData.
   2603     for (const auto& attr : op_def->attr()) {
   2604       if (attr_name == attr.name()) {
   2605         SetOpAttrWithDefaults(op_exec_info.ctx, op, attr, attr_name.data(),
   2606                               py_attr_value, &attr_list_sizes, status);
   2607 
   2608         if (TF_GetCode(status) != TF_OK) {
   2609           VLOG(1) << "Falling back to slow path for Op \"" << op_def->name()
   2610                   << "\" since we are unable to set the value for attr \""
   2611                   << attr.name() << "\" due to: " << TF_Message(status);
   2612           RaiseFallbackException(TF_Message(status));
   2613           return nullptr;
   2614         }
   2615 
   2616         break;
   2617       }
   2618     }
   2619   }
   2620 
   2621   TFE_OpSetDevice(op, op_exec_info.device_name, status);
   2622   if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
   2623     return nullptr;
   2624   }
   2625 
   2626   // Flat attrs and inputs as required by the record_gradient call. The attrs
   2627   // here only contain inferred attrs (non-inferred attrs are added directly
   2628   // from the input args).
   2629   // All items in flattened_attrs and flattened_inputs contain
   2630   // Safe_PyObjectPtr - any time something steals a reference to this, it must
   2631   // INCREF.
   2632   // TODO(nareshmodi): figure out why PyList_New/PyList_Append don't work
   2633   // directly.
   2634   std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_attrs =
   2635       nullptr;
   2636   std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_inputs =
   2637       nullptr;
   2638 
   2639   // TODO(nareshmodi): Encapsulate callbacks information into a struct.
   2640   if (op_exec_info.run_callbacks) {
   2641     flattened_attrs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
   2642     flattened_inputs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
   2643   }
   2644 
   2645   // Add inferred attrs and inputs.
   2646   // The following code might set duplicate type attrs. This will result in
   2647   // the CacheKey for the generated AttrBuilder possibly differing from
   2648   // those where the type attrs are correctly set. Inconsistent CacheKeys
   2649   // for ops means that there might be unnecessarily duplicated kernels.
   2650   // TODO(nareshmodi): Fix this.
   2651   for (int i = 0; i < op_def->input_arg_size(); i++) {
   2652     const auto& input_arg = op_def->input_arg(i);
   2653 
   2654     PyObject* input =
   2655         PyTuple_GET_ITEM(args, kFastPathExecuteInputStartIndex + i);
   2656     if (!input_arg.number_attr().empty()) {
   2657       // The item is a homogeneous list.
   2658       if (!RaiseIfNotPySequence(input, input_arg.number_attr())) return nullptr;
   2659       tensorflow::Safe_PyObjectPtr fast_input(
   2660           PySequence_Fast(input, "Could not parse sequence."));
   2661       if (fast_input.get() == nullptr) {
   2662         return nullptr;
   2663       }
   2664       Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get());
   2665 
   2666       TFE_OpSetAttrInt(op, input_arg.number_attr().data(), len);
   2667       if (op_exec_info.run_callbacks) {
   2668         flattened_attrs->emplace_back(
   2669             GetPythonObjectFromString(input_arg.number_attr().data()));
   2670         flattened_attrs->emplace_back(PyLong_FromLong(len));
   2671       }
   2672       attr_list_sizes[input_arg.number_attr()] = len;
   2673 
   2674       if (len > 0) {
   2675         // First item adds the type attr.
   2676         if (!AddInputToOp(&op_exec_info,
   2677                           PySequence_Fast_GET_ITEM(fast_input.get(), 0), true,
   2678                           input_arg, flattened_attrs.get(),
   2679                           flattened_inputs.get(), op, status)) {
   2680           return nullptr;
   2681         }
   2682 
   2683         for (Py_ssize_t j = 1; j < len; j++) {
   2684           // Since the list is homogeneous, we don't need to re-add the attr.
   2685           if (!AddInputToOp(&op_exec_info,
   2686                             PySequence_Fast_GET_ITEM(fast_input.get(), j),
   2687                             false, input_arg, nullptr /* flattened_attrs */,
   2688                             flattened_inputs.get(), op, status)) {
   2689             return nullptr;
   2690           }
   2691         }
   2692       }
   2693     } else if (!input_arg.type_list_attr().empty()) {
   2694       // The item is a heterogeneous list.
   2695       if (!RaiseIfNotPySequence(input, input_arg.type_list_attr())) {
   2696         return nullptr;
   2697       }
   2698       const string& attr_name = input_arg.type_list_attr();
   2699       Py_ssize_t len = PySequence_Fast_GET_SIZE(input);
   2700       tensorflow::gtl::InlinedVector<TF_DataType, 4> attr_value(len);
   2701       PyObject* py_attr_value = nullptr;
   2702       if (op_exec_info.run_callbacks) {
   2703         py_attr_value = PyTuple_New(len);
   2704       }
   2705       for (Py_ssize_t j = 0; j < len; j++) {
   2706         PyObject* py_input = PySequence_Fast_GET_ITEM(input, j);
   2707         tensorflow::Safe_PyObjectPtr py_eager_tensor;
   2708         if (!ConvertToTensor(
   2709                 op_exec_info, py_input, &py_eager_tensor,
   2710                 []() { Py_RETURN_NONE; }, [](const TF_DataType& dtype) {},
   2711                 status)) {
   2712           return nullptr;
   2713         }
   2714 
   2715         TFE_TensorHandle* input_handle =
   2716             EagerTensor_Handle(py_eager_tensor.get());
   2717 
   2718         attr_value[j] = TFE_TensorHandleDataType(input_handle);
   2719 
   2720         TFE_OpAddInput(op, input_handle, status);
   2721         if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
   2722           return nullptr;
   2723         }
   2724 
   2725         if (op_exec_info.run_callbacks) {
   2726           flattened_inputs->emplace_back(std::move(py_eager_tensor));
   2727 
   2728           PyTuple_SET_ITEM(py_attr_value, j, PyLong_FromLong(attr_value[j]));
   2729         }
   2730       }
   2731       if (op_exec_info.run_callbacks) {
   2732         flattened_attrs->emplace_back(
   2733             GetPythonObjectFromString(attr_name.data()));
   2734         flattened_attrs->emplace_back(py_attr_value);
   2735       }
   2736       TFE_OpSetAttrTypeList(op, attr_name.data(), attr_value.data(),
   2737                             attr_value.size());
   2738       attr_list_sizes[attr_name] = len;
   2739     } else {
   2740       // The item is a single item.
   2741       if (!AddInputToOp(&op_exec_info, input, true, input_arg,
   2742                         flattened_attrs.get(), flattened_inputs.get(), op,
   2743                         status)) {
   2744         return nullptr;
   2745       }
   2746     }
   2747   }
   2748 
   2749   int num_retvals = 0;
   2750   for (int i = 0; i < op_def->output_arg_size(); i++) {
   2751     const auto& output_arg = op_def->output_arg(i);
   2752     int delta = 1;
   2753     if (!output_arg.number_attr().empty()) {
   2754       delta = attr_list_sizes[output_arg.number_attr()];
   2755     } else if (!output_arg.type_list_attr().empty()) {
   2756       delta = attr_list_sizes[output_arg.type_list_attr()];
   2757     }
   2758     if (delta < 0) {
   2759       RaiseFallbackException(
   2760           "Attributes suggest that the size of an output list is less than 0");
   2761       return nullptr;
   2762     }
   2763     num_retvals += delta;
   2764   }
   2765 
   2766   tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals);
   2767 
   2768   Py_BEGIN_ALLOW_THREADS;
   2769   TFE_Execute(op, retvals.data(), &num_retvals, status);
   2770   Py_END_ALLOW_THREADS;
   2771 
   2772   if (TF_GetCode(status) != TF_OK) {
   2773     // Augment the status with the op_name for easier debugging similar to
   2774     // TFE_Py_Execute.
   2775     TF_SetStatus(status, TF_GetCode(status),
   2776                  tensorflow::strings::StrCat(
   2777                      TF_Message(status),
   2778                      " [Op:", TFE_GetPythonString(op_exec_info.op_name), "]")
   2779                      .c_str());
   2780 
   2781     MaybeRaiseExceptionFromTFStatus(status, nullptr);
   2782     return nullptr;
   2783   }
   2784 
   2785   tensorflow::Safe_PyObjectPtr flat_result(PyList_New(num_retvals));
   2786   for (int i = 0; i < num_retvals; ++i) {
   2787     PyList_SET_ITEM(flat_result.get(), i, EagerTensorFromHandle(retvals[i]));
   2788   }
   2789 
   2790   if (!RunCallbacks(op_exec_info, args, flattened_inputs.get(),
   2791                     flattened_attrs.get(), flat_result.get())) {
   2792     return nullptr;
   2793   }
   2794 
   2795   // Unflatten results.
   2796   if (op_def->output_arg_size() == 0) {
   2797     Py_RETURN_NONE;
   2798   }
   2799 
   2800   if (op_def->output_arg_size() == 1) {
   2801     if (!op_def->output_arg(0).number_attr().empty() ||
   2802         !op_def->output_arg(0).type_list_attr().empty()) {
   2803       return flat_result.release();
   2804     } else {
   2805       auto* result = PyList_GET_ITEM(flat_result.get(), 0);
   2806       Py_INCREF(result);
   2807       return result;
   2808     }
   2809   }
   2810 
   2811   // Correctly output the results that are made into a namedtuple.
   2812   PyObject* result = PyList_New(op_def->output_arg_size());
   2813   int flat_result_index = 0;
   2814   for (int i = 0; i < op_def->output_arg_size(); i++) {
   2815     if (!op_def->output_arg(i).number_attr().empty()) {
   2816       int list_length = attr_list_sizes[op_def->output_arg(i).number_attr()];
   2817       PyObject* inner_list = PyList_New(list_length);
   2818       for (int j = 0; j < list_length; j++) {
   2819         PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
   2820         Py_INCREF(obj);
   2821         PyList_SET_ITEM(inner_list, j, obj);
   2822       }
   2823       PyList_SET_ITEM(result, i, inner_list);
   2824     } else if (!op_def->output_arg(i).type_list_attr().empty()) {
   2825       int list_length = attr_list_sizes[op_def->output_arg(i).type_list_attr()];
   2826       PyObject* inner_list = PyList_New(list_length);
   2827       for (int j = 0; j < list_length; j++) {
   2828         PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
   2829         Py_INCREF(obj);
   2830         PyList_SET_ITEM(inner_list, j, obj);
   2831       }
   2832       PyList_SET_ITEM(result, i, inner_list);
   2833     } else {
   2834       PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
   2835       Py_INCREF(obj);
   2836       PyList_SET_ITEM(result, i, obj);
   2837     }
   2838   }
   2839   return result;
   2840 }
   2841 
   2842 PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
   2843                                 PyObject* attrs, PyObject* results,
   2844                                 PyObject* name) {
   2845   if (*ThreadTapeIsStopped() || GetTapeSet()->empty()) {
   2846     Py_RETURN_NONE;
   2847   }
   2848 
   2849   return RecordGradient(op_name, inputs, attrs, results, name);
   2850 }
   2851 
   2852 namespace {
   2853 const char kTensor[] = "T";
   2854 const char kIndexedSlices[] = "I";
   2855 const char kList[] = "L";
   2856 const char kListEnd[] = "l";
   2857 const char kTuple[] = "U";
   2858 const char kTupleEnd[] = "u";
   2859 const char kDict[] = "D";
   2860 const char kRaw[] = "R";
   2861 const char kShape[] = "s";
   2862 const char kShapeDelim[] = "-";
   2863 const char kDType[] = "d";
   2864 const char kNone[] = "n";
   2865 
   2866 struct EncodeResult {
   2867   string str;
   2868   std::vector<PyObject*> objects;
   2869 
   2870   PyObject* ToPyTuple() {
   2871     PyObject* result = PyTuple_New(2);
   2872 
   2873     PyTuple_SET_ITEM(result, 0, GetPythonObjectFromString(str.c_str()));
   2874 
   2875     if (objects.empty()) {
   2876       Py_INCREF(Py_None);
   2877       PyTuple_SET_ITEM(result, 1, Py_None);
   2878     } else {
   2879       PyObject* objects_tuple = PyTuple_New(objects.size());
   2880 
   2881       for (int i = 0; i < objects.size(); i++) {
   2882         PyTuple_SET_ITEM(objects_tuple, i, objects[i]);
   2883       }
   2884 
   2885       PyTuple_SET_ITEM(result, 1, objects_tuple);
   2886     }
   2887 
   2888     return result;
   2889   }
   2890 };
   2891 
   2892 tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg,
   2893                                        bool include_tensor_ranks_only,
   2894                                        EncodeResult* result) {
   2895   if (EagerTensor_CheckExact(arg)) {
   2896     TFE_TensorHandle* t = EagerTensor_Handle(arg);
   2897     tensorflow::TensorShape tensor_shape;
   2898     TF_RETURN_IF_ERROR(t->handle->Shape(&tensor_shape));
   2899 
   2900     absl::StrAppend(&result->str, kDType, t->handle->dtype);
   2901 
   2902     absl::StrAppend(&result->str, kShape);
   2903     if (include_tensor_ranks_only) {
   2904       absl::StrAppend(&result->str, tensor_shape.dim_sizes().size());
   2905     } else {
   2906       for (tensorflow::int64 dim_size : tensor_shape.dim_sizes()) {
   2907         absl::StrAppend(&result->str, dim_size, kShapeDelim);
   2908       }
   2909     }
   2910     return tensorflow::Status::OK();
   2911   }
   2912 
   2913   tensorflow::Safe_PyObjectPtr dtype_object(
   2914       PyObject_GetAttrString(arg, "dtype"));
   2915 
   2916   if (dtype_object == nullptr) {
   2917     return tensorflow::errors::InvalidArgument(
   2918         "ops.Tensor object doesn't have dtype() attr.");
   2919   }
   2920 
   2921   tensorflow::Safe_PyObjectPtr dtype_enum(
   2922       PyObject_GetAttrString(dtype_object.get(), "_type_enum"));
   2923 
   2924   if (dtype_enum == nullptr) {
   2925     return tensorflow::errors::InvalidArgument(
   2926         "ops.Tensor's dtype object doesn't have _type_enum() attr.");
   2927   }
   2928 
   2929   tensorflow::DataType dtype =
   2930       static_cast<tensorflow::DataType>(MakeInt(dtype_enum.get()));
   2931 
   2932   absl::StrAppend(&result->str, kDType, dtype);
   2933 
   2934   static char _shape_tuple[] = "_shape_tuple";
   2935   tensorflow::Safe_PyObjectPtr shape_tuple(
   2936       PyObject_CallMethod(arg, _shape_tuple, nullptr));
   2937 
   2938   if (shape_tuple == nullptr) {
   2939     return tensorflow::errors::InvalidArgument(
   2940         "ops.Tensor object doesn't have _shape_tuple() method.");
   2941   }
   2942 
   2943   if (shape_tuple.get() == Py_None) {
   2944     // Unknown shape, encode that directly.
   2945     absl::StrAppend(&result->str, kNone);
   2946     return tensorflow::Status::OK();
   2947   }
   2948 
   2949   absl::StrAppend(&result->str, kShape);
   2950   tensorflow::Safe_PyObjectPtr shape_seq(PySequence_Fast(
   2951       shape_tuple.get(), "shape_tuple didn't return a sequence"));
   2952 
   2953   int len = PySequence_Fast_GET_SIZE(shape_seq.get());
   2954 
   2955   if (include_tensor_ranks_only) {
   2956     absl::StrAppend(&result->str, len);
   2957   } else {
   2958     for (int i = 0; i < len; ++i) {
   2959       PyObject* item = PySequence_Fast_GET_ITEM(shape_seq.get(), i);
   2960       if (item == Py_None) {
   2961         absl::StrAppend(&result->str, kNone);
   2962       } else {
   2963         absl::StrAppend(&result->str, MakeInt(item));
   2964       }
   2965     }
   2966   }
   2967   return tensorflow::Status::OK();
   2968 }
   2969 
   2970 tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg,
   2971                                           bool include_tensor_ranks_only,
   2972                                           EncodeResult* result);
   2973 
   2974 // This function doesn't set the type of sequence before
   2975 tensorflow::Status TFE_Py_EncodeSequence(PyObject* arg, const char* type,
   2976                                          const char* end_type,
   2977                                          bool include_tensor_ranks_only,
   2978                                          EncodeResult* result) {
   2979   tensorflow::Safe_PyObjectPtr arg_seq(
   2980       PySequence_Fast(arg, "unable to create seq from list/tuple"));
   2981 
   2982   absl::StrAppend(&result->str, type);
   2983   int len = PySequence_Fast_GET_SIZE(arg_seq.get());
   2984   for (int i = 0; i < len; ++i) {
   2985     PyObject* item = PySequence_Fast_GET_ITEM(arg_seq.get(), i);
   2986     if (item == Py_None) {
   2987       absl::StrAppend(&result->str, kNone);
   2988     } else {
   2989       TF_RETURN_IF_ERROR(
   2990           TFE_Py_EncodeArgHelper(item, include_tensor_ranks_only, result));
   2991     }
   2992   }
   2993   absl::StrAppend(&result->str, end_type);
   2994 
   2995   return tensorflow::Status::OK();
   2996 }
   2997 
   2998 tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg,
   2999                                           bool include_tensor_ranks_only,
   3000                                           EncodeResult* result) {
   3001   if (tensorflow::swig::IsTensor(arg)) {
   3002     absl::StrAppend(&result->str, kTensor);
   3003     TF_RETURN_IF_ERROR(
   3004         TFE_Py_EncodeTensor(arg, include_tensor_ranks_only, result));
   3005   } else if (tensorflow::swig::IsIndexedSlices(arg)) {
   3006     absl::StrAppend(&result->str, kIndexedSlices);
   3007     tensorflow::Safe_PyObjectPtr values(PyObject_GetAttrString(arg, "values"));
   3008     if (values == nullptr) {
   3009       PyErr_Clear();
   3010       return tensorflow::errors::InvalidArgument(
   3011           "IndexedSlices does not have a values attr");
   3012     }
   3013     TF_RETURN_IF_ERROR(
   3014         TFE_Py_EncodeTensor(values.get(), include_tensor_ranks_only, result));
   3015 
   3016     tensorflow::Safe_PyObjectPtr indices(
   3017         PyObject_GetAttrString(arg, "indices"));
   3018     if (indices == nullptr) {
   3019       PyErr_Clear();
   3020       return tensorflow::errors::InvalidArgument(
   3021           "IndexedSlices does not have a indices attr");
   3022     }
   3023     TF_RETURN_IF_ERROR(
   3024         TFE_Py_EncodeTensor(indices.get(), include_tensor_ranks_only, result));
   3025 
   3026     tensorflow::Safe_PyObjectPtr dense_shape(
   3027         PyObject_GetAttrString(arg, "dense_shape"));
   3028     if (dense_shape == nullptr) {
   3029       PyErr_Clear();
   3030       return tensorflow::errors::InvalidArgument(
   3031           "IndexedSlices does not have a dense_shape attr");
   3032     }
   3033     if (dense_shape.get() != Py_None) {
   3034       TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(
   3035           dense_shape.get(), include_tensor_ranks_only, result));
   3036     }
   3037   } else if (PyList_Check(arg)) {
   3038     TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(
   3039         arg, kList, kListEnd, include_tensor_ranks_only, result));
   3040   } else if (PyTuple_Check(arg)) {
   3041     TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(
   3042         arg, kTuple, kTupleEnd, include_tensor_ranks_only, result));
   3043   } else if (PyDict_Check(arg)) {
   3044     tensorflow::Safe_PyObjectPtr keys(PyDict_Keys(arg));
   3045     if (PyList_Sort(keys.get()) == -1) {
   3046       return tensorflow::errors::Internal("Unable to sort keys");
   3047     }
   3048 
   3049     absl::StrAppend(&result->str, kDict);
   3050     int len = PyList_Size(keys.get());
   3051 
   3052     for (int i = 0; i < len; i++) {
   3053       PyObject* key = PyList_GetItem(keys.get(), i);
   3054       TF_RETURN_IF_ERROR(
   3055           TFE_Py_EncodeArgHelper(key, include_tensor_ranks_only, result));
   3056       PyObject* value = PyDict_GetItem(arg, key);
   3057       TF_RETURN_IF_ERROR(
   3058           TFE_Py_EncodeArgHelper(value, include_tensor_ranks_only, result));
   3059     }
   3060   } else {
   3061     PyObject* object = PyWeakref_NewRef(arg, nullptr);
   3062 
   3063     if (object == nullptr) {
   3064       PyErr_Clear();
   3065 
   3066       object = arg;
   3067       Py_INCREF(object);
   3068     }
   3069 
   3070     absl::StrAppend(&result->str, kRaw);
   3071     result->objects.push_back(object);
   3072   }
   3073 
   3074   return tensorflow::Status::OK();
   3075 }
   3076 
   3077 }  // namespace
   3078 
   3079 // `defun` uses dtypes and shapes instead of `Tensors` as cache keys. Dtypes
   3080 // are used because TensorFlow graphs are not parametric w.r.t. dtypes. Shapes
   3081 // are used for both performance reasons, as much TensorFlow code specializes
   3082 // on known shapes to produce slimmer graphs, and correctness, as some
   3083 // high-level APIs require shapes to be fully-known.
   3084 //
   3085 // `include_tensor_ranks_only` allows caching on arguments excluding shape info,
   3086 // so that a slow path using relaxed shape can rely on a cache key that excludes
   3087 // shapes.
   3088 //
   3089 // TODO(nareshmodi): Add support for sparse tensors.
   3090 PyObject* TFE_Py_EncodeArg(PyObject* arg, bool include_tensor_ranks_only) {
   3091   EncodeResult result;
   3092   const auto status =
   3093       TFE_Py_EncodeArgHelper(arg, include_tensor_ranks_only, &result);
   3094   if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
   3095     return nullptr;
   3096   }
   3097 
   3098   return result.ToPyTuple();
   3099 }
   3100 
   3101 // A method prints incoming messages directly to Python's
   3102 // stdout using Python's C API. This is necessary in Jupyter notebooks
   3103 // and colabs where messages to the C stdout don't go to the notebook
   3104 // cell outputs, but calls to Python's stdout do.
   3105 void PrintToPythonStdout(const char* msg) {
   3106   if (Py_IsInitialized()) {
   3107     PyGILState_STATE py_threadstate;
   3108     py_threadstate = PyGILState_Ensure();
   3109 
   3110     string string_msg = msg;
   3111     // PySys_WriteStdout truncates strings over 1000 bytes, so
   3112     // we write the message in chunks small enough to not be truncated.
   3113     int CHUNK_SIZE = 900;
   3114     auto len = string_msg.length();
   3115     for (int i = 0; i < len; i += CHUNK_SIZE) {
   3116       PySys_WriteStdout("%s", string_msg.substr(i, CHUNK_SIZE).c_str());
   3117     }
   3118     PySys_WriteStdout("\n");
   3119 
   3120     PyGILState_Release(py_threadstate);
   3121   }
   3122 }
   3123 
   3124 // Register PrintToPythonStdout as a log listener, to allow
   3125 // printing in colabs and jupyter notebooks to work.
   3126 void TFE_Py_EnableInteractivePythonLogging() {
   3127   static bool enabled_interactive_logging = false;
   3128   if (!enabled_interactive_logging) {
   3129     enabled_interactive_logging = true;
   3130     TF_RegisterLogListener(PrintToPythonStdout);
   3131   }
   3132 }
   3133