Home | History | Annotate | Download | only in eager
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <thread>
     17 
     18 #include "tensorflow/python/eager/pywrap_tfe.h"
     19 
     20 #include "tensorflow/c/c_api.h"
     21 #include "tensorflow/c/c_api_internal.h"
     22 #include "tensorflow/c/eager/c_api_internal.h"
     23 #include "tensorflow/c/eager/tape.h"
     24 #include "tensorflow/core/lib/gtl/cleanup.h"
     25 #include "tensorflow/core/lib/gtl/compactptrset.h"
     26 #include "tensorflow/core/lib/gtl/flatmap.h"
     27 #include "tensorflow/core/lib/strings/strcat.h"
     28 #include "tensorflow/core/lib/strings/stringprintf.h"
     29 #include "tensorflow/core/platform/mutex.h"
     30 #include "tensorflow/core/platform/protobuf.h"
     31 #include "tensorflow/core/platform/types.h"
     32 #include "tensorflow/python/eager/pywrap_tensor.h"
     33 
     34 using tensorflow::string;
     35 using tensorflow::strings::Printf;
     36 
     37 namespace {
     38 
// Generates a parser function `fn_name` that converts `py_value` into the C
// type `type` using the CPython `check_fn`/`parse_fn` pair.  On success the
// converted value is stored in `*value` and true is returned.  On a type
// mismatch, a TF_INVALID_ARGUMENT error naming the attr `key` and the actual
// Python type is recorded on `status` and false is returned.
#define PARSE_VALUE(fn_name, type, check_fn, parse_fn)                       \
  bool fn_name(const string& key, PyObject* py_value, TF_Status* status,     \
               type* value) {                                                \
    if (check_fn(py_value)) {                                                \
      *value = static_cast<type>(parse_fn(py_value));                        \
      return true;                                                           \
    } else {                                                                 \
      TF_SetStatus(status, TF_INVALID_ARGUMENT,                              \
                   tensorflow::strings::StrCat(                              \
                       "Expecting " #type " value for attr ", key, ", got ", \
                       py_value->ob_type->tp_name)                           \
                       .c_str());                                            \
      return false;                                                          \
    }                                                                        \
  }

// Python 3 has a single integer type (PyLong); Python 2 distinguishes
// PyInt from PyLong, hence the extra ParseInt64LongValue variant there.
#if PY_MAJOR_VERSION >= 3
PARSE_VALUE(ParseIntValue, int, PyLong_Check, PyLong_AsLong)
PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLong)
#else
PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong)
PARSE_VALUE(ParseInt64Value, int64_t, PyInt_Check, PyInt_AsLong)
PARSE_VALUE(ParseInt64LongValue, int64_t, PyLong_Check, PyLong_AsLong)
#endif
PARSE_VALUE(ParseFloatValue, float, PyFloat_Check, PyFloat_AsDouble)
#undef PARSE_VALUE
     65 
     66 Py_ssize_t TensorShapeNumDims(PyObject* value) {
     67   const auto size = PySequence_Size(value);
     68   if (size == -1) {
     69     // TensorShape.__len__ raises an error in the scenario where the shape is an
     70     // unknown, which needs to be cleared.
     71     // TODO(nareshmodi): ensure that this is actually a TensorShape.
     72     PyErr_Clear();
     73   }
     74   return size;
     75 }
     76 
// Parses a string attr value from `py_value` into `*value`.  Accepts bytes
// objects (and, under Python 3, unicode objects encoded as UTF-8).  The
// returned pointer borrows py_value's internal buffer and is only valid
// while py_value stays alive.  For any other type, records a
// TF_INVALID_ARGUMENT error on `status` and returns false.
bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status,
                      const char** value) {
  if (PyBytes_Check(py_value)) {
    *value = PyBytes_AsString(py_value);
    return true;
  }
#if PY_MAJOR_VERSION >= 3
  if (PyUnicode_Check(py_value)) {
    // NOTE(review): PyUnicode_AsUTF8 can return nullptr on encoding failure;
    // callers appear to assume valid attr names/values — confirm.
    *value = PyUnicode_AsUTF8(py_value);
    return true;
  }
#endif
  TF_SetStatus(
      status, TF_INVALID_ARGUMENT,
      tensorflow::strings::StrCat("Expecting a string value for attr ", key,
                                  ", got ", py_value->ob_type->tp_name)
          .c_str());
  return false;
}
     96 
     97 bool ParseBoolValue(const string& key, PyObject* py_value, TF_Status* status,
     98                     unsigned char* value) {
     99   *value = PyObject_IsTrue(py_value);
    100   return true;
    101 }
    102 
// Returns true if py_value is a Python integer object: PyLong under
// Python 3, PyInt under Python 2 (Python 2 PyLong values are handled
// separately via ParseInt64LongValue).
bool IsInteger(PyObject* py_value) {
#if PY_MAJOR_VERSION >= 3
  return PyLong_Check(py_value);
#else
  return PyInt_Check(py_value);
#endif
}
    110 
    111 // The passed in py_value is expected to be an object of the python type
    112 // dtypes.DType or an int.
    113 bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status,
    114                     int* value) {
    115   if (IsInteger(py_value)) {
    116     return ParseIntValue(key, py_value, status, value);
    117   }
    118 
    119   PyObject* py_type_enum = PyObject_GetAttrString(py_value, "_type_enum");
    120   if (py_type_enum == nullptr) {
    121     return false;
    122   }
    123 
    124   if (!ParseIntValue(key, py_type_enum, status, value)) {
    125     Py_DECREF(py_type_enum);
    126     return false;
    127   }
    128 
    129   Py_DECREF(py_type_enum);
    130   return true;
    131 }
    132 
// Sets the list-valued attr `key` on `op` from the Python sequence `py_list`,
// converting each element according to the expected attr `type`.  The list
// length is recorded in `attr_list_sizes` (when non-null) so later code can
// size dependent output lists.  Returns false with `status` set on any
// conversion failure or unsupported list type.
bool SetOpAttrList(
    TFE_Op* op, const char* key, PyObject* py_list, TF_AttrType type,
    tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
    TF_Status* status) {
  if (!PySequence_Check(py_list)) {
    TF_SetStatus(
        status, TF_INVALID_ARGUMENT,
        tensorflow::strings::StrCat("Expecting sequence value for attr ", key,
                                    ", got ", py_list->ob_type->tp_name)
            .c_str());
    return false;
  }
  const int num_values = PySequence_Size(py_list);
  if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = num_values;

// Parses every element of py_list into a freshly allocated c_type[num_values]
// array named `values`, bailing out (status already set) on the first bad
// element.  NOTE(review): PySequence_ITEM returns a *new* reference that is
// never released here, both on the happy path and on the early return — this
// leaks one reference per element; confirm and fix.
#define PARSE_LIST(c_type, parse_fn)                                \
  std::unique_ptr<c_type[]> values(new c_type[num_values]);         \
  for (int i = 0; i < num_values; ++i) {                            \
    auto py_value = PySequence_ITEM(py_list, i);                    \
    if (!parse_fn(key, py_value, status, &values[i])) return false; \
  }

  if (type == TF_ATTR_STRING) {
    // String pointers borrow each element's internal buffer; they stay valid
    // because py_list keeps its elements alive for the duration of the call.
    PARSE_LIST(const char*, ParseStringValue);
    TFE_OpSetAttrStringList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_INT) {
    PARSE_LIST(int64_t, ParseInt64Value);
    TFE_OpSetAttrIntList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_FLOAT) {
    PARSE_LIST(float, ParseFloatValue);
    TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_BOOL) {
    PARSE_LIST(unsigned char, ParseBoolValue);
    TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_TYPE) {
    PARSE_LIST(int, ParseTypeValue);
    TFE_OpSetAttrTypeList(op, key,
                          reinterpret_cast<const TF_DataType*>(values.get()),
                          num_values);
  } else if (type == TF_ATTR_SHAPE) {
    // Make one pass through the input counting the total number of
    // dims across all the input lists.
    int total_dims = 0;
    for (int i = 0; i < num_values; ++i) {
      auto py_value = PySequence_ITEM(py_list, i);
      if (py_value != Py_None) {
        if (!PySequence_Check(py_value)) {
          TF_SetStatus(
              status, TF_INVALID_ARGUMENT,
              tensorflow::strings::StrCat(
                  "Expecting None or sequence value for element", i,
                  " of attr ", key, ", got ", py_value->ob_type->tp_name)
                  .c_str());
          return false;
        }
        const auto size = TensorShapeNumDims(py_value);
        // size == -1 means unknown rank; it contributes no dims.
        if (size >= 0) {
          total_dims += size;
        }
      }
    }
    // Allocate a buffer that can fit all of the dims together.
    std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
    // Copy the input dims into the buffer and set dims to point to
    // the start of each list's dims.
    std::unique_ptr<const int64_t* []> dims(new const int64_t*[num_values]);
    std::unique_ptr<int[]> num_dims(new int[num_values]);
    int64_t* offset = buffer.get();
    for (int i = 0; i < num_values; ++i) {
      auto py_value = PySequence_ITEM(py_list, i);
      if (py_value == Py_None) {
        // None denotes a shape of unknown rank.
        dims[i] = nullptr;
        num_dims[i] = -1;
      } else {
        const auto size = TensorShapeNumDims(py_value);
        if (size == -1) {
          // A TensorShape whose rank is unknown.
          dims[i] = nullptr;
          num_dims[i] = -1;
          continue;
        }
        dims[i] = offset;
        num_dims[i] = size;
        for (int j = 0; j < size; ++j) {
          auto inner_py_value = PySequence_ITEM(py_value, j);
          if (inner_py_value == Py_None) {
            // None as a dimension means "unknown size".
            *offset = -1;
          } else if (!ParseInt64Value(key, inner_py_value, status, offset)) {
            return false;
          }
          ++offset;
        }
      }
    }
    TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
                           status);
    if (TF_GetCode(status) != TF_OK) return false;
  } else {
    TF_SetStatus(status, TF_UNIMPLEMENTED,
                 tensorflow::strings::StrCat("Attr ", key,
                                             " has unhandled list type ", type)
                     .c_str());
    return false;
  }
#undef PARSE_LIST
  return true;
}
    239 
// This is only declared here since GetFunc makes a recursive call to
// SetOpAttrScalarDefault.  The definition (with full documentation) appears
// later in this file.
void SetOpAttrScalarDefault(
    TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value,
    const char* attr_name,
    tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
    TF_Status* status);
    247 
    248 TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
    249                 TF_Status* status) {
    250   TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
    251   for (const auto& attr : func.attr()) {
    252     if (TF_GetCode(status) != TF_OK) return nullptr;
    253     SetOpAttrScalarDefault(ctx, func_op, attr.second, attr.first.data(),
    254                            nullptr, status);
    255     if (TF_GetCode(status) != TF_OK) return nullptr;
    256   }
    257   return func_op;
    258 }
    259 
    260 void SetOpAttrListDefault(
    261     TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
    262     const char* key, TF_AttrType type,
    263     tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
    264     TF_Status* status) {
    265   if (type == TF_ATTR_STRING) {
    266     int num_values = attr.default_value().list().s_size();
    267     std::unique_ptr<const char* []> values(new const char*[num_values]);
    268     (*attr_list_sizes)[key] = num_values;
    269     for (int i = 0; i < num_values; i++) {
    270       values[i] = attr.default_value().list().s(i).data();
    271     }
    272     TFE_OpSetAttrStringList(op, key, values.get(), num_values);
    273   } else if (type == TF_ATTR_INT) {
    274     int num_values = attr.default_value().list().i_size();
    275     std::unique_ptr<int64_t[]> values(new int64_t[num_values]);
    276     (*attr_list_sizes)[key] = num_values;
    277     for (int i = 0; i < num_values; i++) {
    278       values[i] = attr.default_value().list().i(i);
    279     }
    280     TFE_OpSetAttrIntList(op, key, values.get(), num_values);
    281   } else if (type == TF_ATTR_FLOAT) {
    282     int num_values = attr.default_value().list().f_size();
    283     std::unique_ptr<float[]> values(new float[num_values]);
    284     (*attr_list_sizes)[key] = num_values;
    285     for (int i = 0; i < num_values; i++) {
    286       values[i] = attr.default_value().list().f(i);
    287     }
    288     TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
    289   } else if (type == TF_ATTR_BOOL) {
    290     int num_values = attr.default_value().list().b_size();
    291     std::unique_ptr<unsigned char[]> values(new unsigned char[num_values]);
    292     (*attr_list_sizes)[key] = num_values;
    293     for (int i = 0; i < num_values; i++) {
    294       values[i] = attr.default_value().list().b(i);
    295     }
    296     TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
    297   } else if (type == TF_ATTR_TYPE) {
    298     int num_values = attr.default_value().list().type_size();
    299     std::unique_ptr<int[]> values(new int[num_values]);
    300     (*attr_list_sizes)[key] = num_values;
    301     for (int i = 0; i < num_values; i++) {
    302       values[i] = attr.default_value().list().type(i);
    303     }
    304     TFE_OpSetAttrTypeList(op, key,
    305                           reinterpret_cast<const TF_DataType*>(values.get()),
    306                           attr.default_value().list().type_size());
    307   } else if (type == TF_ATTR_SHAPE) {
    308     int num_values = attr.default_value().list().shape_size();
    309     (*attr_list_sizes)[key] = num_values;
    310     int total_dims = 0;
    311     for (int i = 0; i < num_values; ++i) {
    312       if (!attr.default_value().list().shape(i).unknown_rank()) {
    313         total_dims += attr.default_value().list().shape(i).dim_size();
    314       }
    315     }
    316     // Allocate a buffer that can fit all of the dims together.
    317     std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
    318     // Copy the input dims into the buffer and set dims to point to
    319     // the start of each list's dims.
    320     std::unique_ptr<const int64_t* []> dims(new const int64_t*[num_values]);
    321     std::unique_ptr<int[]> num_dims(new int[num_values]);
    322     int64_t* offset = buffer.get();
    323     for (int i = 0; i < num_values; ++i) {
    324       const auto& shape = attr.default_value().list().shape(i);
    325       if (shape.unknown_rank()) {
    326         dims[i] = nullptr;
    327         num_dims[i] = -1;
    328       } else {
    329         for (int j = 0; j < shape.dim_size(); j++) {
    330           *offset = shape.dim(j).size();
    331           ++offset;
    332         }
    333       }
    334     }
    335     TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
    336                            status);
    337   } else if (type == TF_ATTR_FUNC) {
    338     int num_values = attr.default_value().list().func_size();
    339     (*attr_list_sizes)[key] = num_values;
    340     std::unique_ptr<const TFE_Op* []> funcs(new const TFE_Op*[num_values]);
    341     for (int i = 0; i < num_values; i++) {
    342       funcs[i] = GetFunc(ctx, attr.default_value().list().func(i), status);
    343     }
    344     TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
    345   } else {
    346     TF_SetStatus(status, TF_UNIMPLEMENTED,
    347                  "Lists of tensors are not yet implemented for default valued "
    348                  "attributes for an operation.");
    349   }
    350 }
    351 
    352 bool SetOpAttrScalar(
    353     TFE_Context* ctx, TFE_Op* op, const char* key, PyObject* py_value,
    354     TF_AttrType type,
    355     tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
    356     TF_Status* status) {
    357   if (type == TF_ATTR_STRING) {
    358     const char* value;
    359     if (!ParseStringValue(key, py_value, status, &value)) return false;
    360     TFE_OpSetAttrString(op, key, value);
    361   } else if (type == TF_ATTR_INT) {
    362     int64_t value;
    363     if (!ParseInt64Value(key, py_value, status, &value)) return false;
    364     TFE_OpSetAttrInt(op, key, value);
    365     // attr_list_sizes is set for all int attributes (since at this point we are
    366     // not aware if that attribute might be used to calculate the size of an
    367     // output list or not).
    368     if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = value;
    369   } else if (type == TF_ATTR_FLOAT) {
    370     float value;
    371     if (!ParseFloatValue(key, py_value, status, &value)) return false;
    372     TFE_OpSetAttrFloat(op, key, value);
    373   } else if (type == TF_ATTR_BOOL) {
    374     unsigned char value;
    375     if (!ParseBoolValue(key, py_value, status, &value)) return false;
    376     TFE_OpSetAttrBool(op, key, value);
    377   } else if (type == TF_ATTR_TYPE) {
    378     int value;
    379     if (!ParseTypeValue(key, py_value, status, &value)) return false;
    380     TFE_OpSetAttrType(op, key, static_cast<TF_DataType>(value));
    381   } else if (type == TF_ATTR_SHAPE) {
    382     if (py_value == Py_None) {
    383       TFE_OpSetAttrShape(op, key, nullptr, -1, status);
    384     } else {
    385       if (!PySequence_Check(py_value)) {
    386         TF_SetStatus(status, TF_INVALID_ARGUMENT,
    387                      tensorflow::strings::StrCat(
    388                          "Expecting None or sequence value for attr", key,
    389                          ", got ", py_value->ob_type->tp_name)
    390                          .c_str());
    391         return false;
    392       }
    393       const auto num_dims = TensorShapeNumDims(py_value);
    394       if (num_dims == -1) {
    395         TFE_OpSetAttrShape(op, key, nullptr, -1, status);
    396         return true;
    397       }
    398       std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
    399       for (int i = 0; i < num_dims; ++i) {
    400         auto inner_py_value = PySequence_ITEM(py_value, i);
    401         if (inner_py_value == Py_None) {
    402           dims[i] = -1;
    403         } else if (!ParseInt64Value(key, inner_py_value, status, &dims[i])) {
    404           return false;
    405         }
    406       }
    407       TFE_OpSetAttrShape(op, key, dims.get(), num_dims, status);
    408     }
    409     if (TF_GetCode(status) != TF_OK) return false;
    410   } else if (type == TF_ATTR_FUNC) {
    411     // Allow:
    412     // (1) String function name, OR
    413     // (2) A Python object with a .name attribute
    414     //     (A crude test for being a
    415     //     tensorflow.python.framework.function._DefinedFunction)
    416     //     (which is what the various "defun" or "Defun" decorators do).
    417     // And in the future also allow an object that can encapsulate
    418     // the function name and its attribute values.
    419     const char* func_name = nullptr;
    420     if (!ParseStringValue(key, py_value, status, &func_name)) {
    421       PyObject* name_attr = PyObject_GetAttrString(py_value, "name");
    422       if (name_attr == nullptr ||
    423           !ParseStringValue(key, name_attr, status, &func_name)) {
    424         TF_SetStatus(
    425             status, TF_INVALID_ARGUMENT,
    426             tensorflow::strings::StrCat(
    427                 "unable to set function value attribute from a ",
    428                 py_value->ob_type->tp_name,
    429                 " object. If you think this is an error, please file an issue "
    430                 "at https://github.com/tensorflow/tensorflow/issues/new")
    431                 .c_str());
    432         return false;
    433       }
    434     }
    435     TFE_Op* func = TFE_NewOp(ctx, func_name, status);
    436     if (TF_GetCode(status) != TF_OK) return false;
    437     TFE_OpSetAttrFunction(op, key, func);
    438     TFE_DeleteOp(func);
    439   } else {
    440     TF_SetStatus(
    441         status, TF_UNIMPLEMENTED,
    442         tensorflow::strings::StrCat("Attr ", key, " has unhandled type ", type)
    443             .c_str());
    444     return false;
    445   }
    446   return true;
    447 }
    448 
    449 void SetOpAttrScalarDefault(
    450     TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value,
    451     const char* attr_name,
    452     tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
    453     TF_Status* status) {
    454   switch (default_value.value_case()) {
    455     case tensorflow::AttrValue::kS:
    456       TFE_OpSetAttrString(op, attr_name, default_value.s().data());
    457       break;
    458     case tensorflow::AttrValue::kI:
    459       TFE_OpSetAttrInt(op, attr_name, static_cast<int64_t>(default_value.i()));
    460       (*attr_list_sizes)[attr_name] = default_value.i();
    461       break;
    462     case tensorflow::AttrValue::kF:
    463       TFE_OpSetAttrFloat(op, attr_name, default_value.f());
    464       break;
    465     case tensorflow::AttrValue::kB:
    466       TFE_OpSetAttrBool(op, attr_name, default_value.b());
    467       break;
    468     case tensorflow::AttrValue::kType:
    469       TFE_OpSetAttrType(op, attr_name,
    470                         static_cast<TF_DataType>(default_value.type()));
    471       break;
    472     case tensorflow::AttrValue::kShape: {
    473       const auto& tensor_shape = default_value.shape();
    474       if (tensor_shape.unknown_rank()) {
    475         TFE_OpSetAttrShape(op, attr_name, nullptr, -1, status);
    476       } else {
    477         const auto num_dims = tensor_shape.dim_size();
    478         std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
    479         for (int i = 0; i < num_dims; ++i) {
    480           dims[i] = tensor_shape.dim(i).size();
    481         }
    482         TFE_OpSetAttrShape(op, attr_name, dims.get(), num_dims, status);
    483       }
    484     } break;
    485     case tensorflow::AttrValue::kFunc: {
    486       const auto func_op = GetFunc(ctx, default_value.func(), status);
    487       if (TF_GetCode(status) != TF_OK) return;
    488       // TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList
    489       // require TFE_Op* and just convert it internally a NameAttrValue, so
    490       // consider adding an overload to the C API to make this case easier.
    491       TFE_OpSetAttrFunction(op, attr_name, func_op);
    492     } break;
    493     case tensorflow::AttrValue::kList:
    494       TF_FALLTHROUGH_INTENDED;
    495     case tensorflow::AttrValue::kTensor:
    496       TF_FALLTHROUGH_INTENDED;
    497     case tensorflow::AttrValue::kPlaceholder:
    498       TF_FALLTHROUGH_INTENDED;
    499     case tensorflow::AttrValue::VALUE_NOT_SET:
    500       TF_SetStatus(
    501           status, TF_UNIMPLEMENTED,
    502           tensorflow::strings::StrCat("Unable to get setfor default value: ",
    503                                       default_value.DebugString())
    504               .data());
    505   }
    506 }
    507 
// start_index is the index at which the Tuple/List attrs will start getting
// processed.
//
// `attrs` is either None (no attrs) or a flat tuple of alternating
// (attr name, attr value) entries.  Each attr name is looked up on `op` to
// learn its expected type and list-ness, then dispatched to SetOpAttrList or
// SetOpAttrScalar.  Processing stops at the first failure, leaving the error
// in out_status.
void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index,
                TF_Status* out_status) {
  if (attrs == Py_None) return;
  Py_ssize_t len = PyTuple_GET_SIZE(attrs) - start_index;
  if ((len & 1) != 0) {
    // Attrs must come in complete (name, value) pairs.
    TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
                 "Expecting attrs tuple to have even length.");
    return;
  }
  // Parse attrs
  for (Py_ssize_t i = 0; i < len; i += 2) {
    // PyTuple_GET_ITEM returns borrowed references — no DECREF needed.
    PyObject* py_key = PyTuple_GET_ITEM(attrs, start_index + i);
    PyObject* py_value = PyTuple_GET_ITEM(attrs, start_index + i + 1);
#if PY_MAJOR_VERSION >= 3
    // NOTE(review): PyUnicode_AsUTF8 can return nullptr on encoding failure;
    // `key` is assumed valid below — confirm attr names are always ASCII.
    const char* key = PyBytes_Check(py_key) ? PyBytes_AsString(py_key)
                                            : PyUnicode_AsUTF8(py_key);
#else
    const char* key = PyBytes_AsString(py_key);
#endif
    unsigned char is_list = 0;
    const TF_AttrType type = TFE_OpGetAttrType(op, key, &is_list, out_status);
    if (TF_GetCode(out_status) != TF_OK) return;
    if (is_list != 0) {
      if (!SetOpAttrList(op, key, py_value, type, nullptr, out_status)) return;
    } else {
      if (!SetOpAttrScalar(ctx, op, key, py_value, type, nullptr, out_status))
        return;
    }
  }
}
    540 
    541 // This function will set the op attrs required. If an attr has the value of
    542 // None, then it will read the AttrDef to get the default value and set that
    543 // instead. Any failure in this function will simply fall back to the slow
    544 // path.
    545 void SetOpAttrWithDefaults(
    546     TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
    547     const char* attr_name, PyObject* attr_value,
    548     tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
    549     TF_Status* status) {
    550   unsigned char is_list = 0;
    551   const TF_AttrType type = TFE_OpGetAttrType(op, attr_name, &is_list, status);
    552   if (TF_GetCode(status) != TF_OK) return;
    553   if (attr_value == Py_None) {
    554     if (is_list != 0) {
    555       SetOpAttrListDefault(ctx, op, attr, attr_name, type, attr_list_sizes,
    556                            status);
    557     } else {
    558       SetOpAttrScalarDefault(ctx, op, attr.default_value(), attr_name,
    559                              attr_list_sizes, status);
    560     }
    561   } else {
    562     if (is_list != 0) {
    563       SetOpAttrList(op, attr_name, attr_value, type, attr_list_sizes, status);
    564     } else {
    565       SetOpAttrScalar(ctx, op, attr_name, attr_value, type, attr_list_sizes,
    566                       status);
    567     }
    568   }
    569 }
    570 
// Python subclass of Exception that is created on not ok Status.
// Registered from Python via TFE_Py_RegisterExceptionClass; holds an owned
// reference while registered.
tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr;

// Python subclass of Exception that is created to signal fallback.
// Registered via TFE_Py_RegisterFallbackExceptionClass (no mutex: accessed
// from Python-facing entry points).
PyObject* fallback_exception_class = nullptr;

// Monotonic counter backing get_uid() / TFE_Py_UID.
tensorflow::mutex _uid_mutex(tensorflow::LINKER_INITIALIZED);
tensorflow::int64 _uid GUARDED_BY(_uid_mutex) = 0;
    580 
    581 }  // namespace
    582 
// Eagerly executes op `op_name` on `device_name` with the given input
// handles and attrs tuple (see SetOpAttrs for the layout), writing the
// produced handles into `outputs` (resized to the actual output count).
// On failure out_status carries the error, suffixed with " [Op:<name>]".
void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
                    const char* op_name, TFE_InputTensorHandles* inputs,
                    PyObject* attrs, TFE_OutputTensorHandles* outputs,
                    TF_Status* out_status) {
  TFE_Op* op = TFE_NewOp(ctx, op_name, out_status);
  if (TF_GetCode(out_status) != TF_OK) return;
  TFE_OpSetDevice(op, device_name, out_status);
  if (TF_GetCode(out_status) == TF_OK) {
    // Each subsequent step only runs while the status is still OK, so the
    // first failure short-circuits everything up to the cleanup below.
    for (int i = 0; i < inputs->size() && TF_GetCode(out_status) == TF_OK;
         ++i) {
      TFE_OpAddInput(op, inputs->at(i), out_status);
    }
  }
  if (TF_GetCode(out_status) == TF_OK) {
    SetOpAttrs(ctx, op, attrs, 0, out_status);
  }
  // Release the GIL while the (potentially long-running) kernel executes.
  Py_BEGIN_ALLOW_THREADS;
  if (TF_GetCode(out_status) == TF_OK) {
    int num_outputs = outputs->size();
    TFE_Execute(op, outputs->data(), &num_outputs, out_status);
    outputs->resize(num_outputs);
  }
  if (TF_GetCode(out_status) != TF_OK) {
    // Tag the error with the op name to aid debugging.
    TF_SetStatus(out_status, TF_GetCode(out_status),
                 tensorflow::strings::StrCat(TF_Message(out_status),
                                             " [Op:", op_name, "]")
                     .c_str());
  }
  TFE_DeleteOp(op);
  Py_END_ALLOW_THREADS;
}
    614 
    615 PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
    616   tensorflow::mutex_lock l(exception_class_mutex);
    617   if (exception_class != nullptr) {
    618     Py_DECREF(exception_class);
    619   }
    620   if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
    621     exception_class = nullptr;
    622     PyErr_SetString(PyExc_TypeError,
    623                     "TFE_Py_RegisterExceptionClass: "
    624                     "Registered class should be subclass of Exception.");
    625     return nullptr;
    626   } else {
    627     Py_INCREF(e);
    628     exception_class = e;
    629     Py_RETURN_NONE;
    630   }
    631 }
    632 
    633 PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) {
    634   if (fallback_exception_class != nullptr) {
    635     Py_DECREF(fallback_exception_class);
    636   }
    637   if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
    638     fallback_exception_class = nullptr;
    639     PyErr_SetString(PyExc_TypeError,
    640                     "TFE_Py_RegisterFallbackExceptionClass: "
    641                     "Registered class should be subclass of Exception.");
    642     return nullptr;
    643   } else {
    644     Py_INCREF(e);
    645     fallback_exception_class = e;
    646     Py_RETURN_NONE;
    647   }
    648 }
    649 
    650 void RaiseFallbackException(const char* message) {
    651   if (fallback_exception_class != nullptr) {
    652     PyErr_SetObject(fallback_exception_class, Py_BuildValue("s", message));
    653     return;
    654   }
    655 
    656   PyErr_SetString(
    657       PyExc_RuntimeError,
    658       tensorflow::strings::StrCat(
    659           "Fallback exception type not set, attempting to fallback due to ",
    660           message)
    661           .data());
    662 }
    663 
    664 int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) {
    665   if (TF_GetCode(status) == TF_OK) return 0;
    666   const char* msg = TF_Message(status);
    667   if (exception == nullptr) {
    668     tensorflow::mutex_lock l(exception_class_mutex);
    669     if (exception_class != nullptr) {
    670       PyErr_SetObject(exception_class,
    671                       Py_BuildValue("si", msg, TF_GetCode(status)));
    672       return -1;
    673     } else {
    674       exception = PyExc_RuntimeError;
    675     }
    676   }
    677   // May be update already set exception.
    678   PyErr_SetString(exception, msg);
    679   return -1;
    680 }
    681 
    682 int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
    683                                   PyObject* exception) {
    684   if (status.ok()) return 0;
    685   const char* msg = status.error_message().c_str();
    686   if (exception == nullptr) {
    687     tensorflow::mutex_lock l(exception_class_mutex);
    688     if (exception_class != nullptr) {
    689       PyErr_SetObject(exception_class, Py_BuildValue("si", msg, status.code()));
    690       return -1;
    691     } else {
    692       exception = PyExc_RuntimeError;
    693     }
    694   }
    695   // May be update already set exception.
    696   PyErr_SetString(exception, msg);
    697   return -1;
    698 }
    699 
    700 char* TFE_GetPythonString(PyObject* o) {
    701   if (PyBytes_Check(o)) {
    702     return PyBytes_AsString(o);
    703   }
    704 #if PY_MAJOR_VERSION >= 3
    705   if (PyUnicode_Check(o)) {
    706     return PyUnicode_AsUTF8(o);
    707   }
    708 #endif
    709   return nullptr;
    710 }
    711 
    712 int64_t get_uid() {
    713   tensorflow::mutex_lock l(_uid_mutex);
    714   return _uid++;
    715 }
    716 
    717 PyObject* TFE_Py_UID() { return PyLong_FromLongLong(get_uid()); }
    718 
    719 void TFE_DeleteContextCapsule(PyObject* context) {
    720   TF_Status* status = TF_NewStatus();
    721   TFE_Context* ctx =
    722       reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(context, nullptr));
    723   TFE_DeleteContext(ctx, status);
    724   TF_DeleteStatus(status);
    725 }
    726 
// Converts a Python integer object to tensorflow::int64. On conversion
// failure the underlying PyLong_AsLong/PyInt_AsLong returns -1 with a Python
// error set; callers are expected to check PyErr_Occurred().
static tensorflow::int64 MakeInt(PyObject* integer) {
#if PY_MAJOR_VERSION >= 3
  return PyLong_AsLong(integer);
#else
  return PyInt_AsLong(integer);
#endif
}
    734 
    735 static tensorflow::int64 FastTensorId(PyObject* tensor) {
    736   if (EagerTensor_CheckExact(tensor)) {
    737     return EagerTensor_id(tensor);
    738   }
    739   PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
    740   if (id_field == nullptr) {
    741     return -1;
    742   }
    743   tensorflow::int64 id = MakeInt(id_field);
    744   Py_DECREF(id_field);
    745   return id;
    746 }
    747 
    748 class GradientTape
    749     : public tensorflow::eager::GradientTape<PyObject, PyObject> {
    750  public:
    751   explicit GradientTape(bool persistent)
    752       : tensorflow::eager::GradientTape<PyObject, PyObject>(persistent) {}
    753 
    754   virtual ~GradientTape() {
    755     for (PyObject* v : watched_variables_) {
    756       Py_DECREF(v);
    757     }
    758   }
    759 
    760   void WatchVariable(PyObject* v) {
    761     auto insert_result = watched_variables_.insert(v);
    762     if (insert_result.second) {
    763       // Only increment the reference count if we aren't already watching this
    764       // variable.
    765       Py_INCREF(v);
    766     }
    767     PyObject* handle = PyObject_GetAttrString(v, "handle");
    768     if (handle == nullptr) {
    769       return;
    770     }
    771     tensorflow::int64 id = FastTensorId(handle);
    772     Py_DECREF(handle);
    773     if (!PyErr_Occurred()) {
    774       this->Watch(id);
    775     }
    776   }
    777 
    778   const std::unordered_set<PyObject*> WatchedVariables() {
    779     return watched_variables_;
    780   }
    781 
    782  private:
    783   std::unordered_set<PyObject*> watched_variables_;
    784 };
    785 
// Python object wrapping a C++ GradientTape. `tape` is owned by the Python
// object and deleted in TFE_Py_Tape_Delete.
typedef struct {
  PyObject_HEAD
      /* Type-specific fields go here. */
      GradientTape* tape;
} TFE_Py_Tape;
    791 
    792 static void TFE_Py_Tape_Delete(PyObject* tape) {
    793   delete reinterpret_cast<TFE_Py_Tape*>(tape)->tape;
    794   Py_TYPE(tape)->tp_free(tape);
    795 }
    796 
// Minimal Python type for tape objects: only deallocation is customized (and
// tp_new is filled in lazily by TFE_Py_TapeSetNew); all other slots use the
// defaults.
static PyTypeObject TFE_Py_Tape_Type = {
    PyVarObject_HEAD_INIT(nullptr, 0) "tfe.Tape", /* tp_name */
    sizeof(TFE_Py_Tape),                          /* tp_basicsize */
    0,                                            /* tp_itemsize */
    &TFE_Py_Tape_Delete,                          /* tp_dealloc */
    nullptr,                                      /* tp_print */
    nullptr,                                      /* tp_getattr */
    nullptr,                                      /* tp_setattr */
    nullptr,                                      /* tp_reserved */
    nullptr,                                      /* tp_repr */
    nullptr,                                      /* tp_as_number */
    nullptr,                                      /* tp_as_sequence */
    nullptr,                                      /* tp_as_mapping */
    nullptr,                                      /* tp_hash  */
    nullptr,                                      /* tp_call */
    nullptr,                                      /* tp_str */
    nullptr,                                      /* tp_getattro */
    nullptr,                                      /* tp_setattro */
    nullptr,                                      /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT,                           /* tp_flags */
    "TFE_Py_Tape objects",                        /* tp_doc */
};
    819 
    820 // Note: in the current design no mutex is needed here because of the python
    821 // GIL, which is always held when any TFE_Py_* methods are called. We should
    822 // revisit this if/when decide to not hold the GIL while manipulating the tape
    823 // stack.
    824 static tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* tape_set = nullptr;
    825 tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() {
    826   if (tape_set == nullptr) {
    827     tape_set = new tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>;
    828   }
    829   return tape_set;
    830 }
    831 
    832 // A safe copy of the current tapeset. Does not get affected by other python
    833 // threads changing the set of active tapes.
    834 class SafeTapeSet {
    835  public:
    836   SafeTapeSet() : tape_set_(*GetTapeSet()) {
    837     for (auto* tape : tape_set_) {
    838       Py_INCREF(tape);
    839     }
    840   }
    841 
    842   ~SafeTapeSet() {
    843     for (auto* tape : tape_set_) {
    844       Py_DECREF(tape);
    845     }
    846   }
    847 
    848   tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>::const_iterator begin() {
    849     return tape_set_.begin();
    850   }
    851 
    852   tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>::const_iterator end() {
    853     return tape_set_.end();
    854   }
    855 
    856  private:
    857   tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*> tape_set_;
    858 };
    859 
    860 // xcode 7 doesn't define thread_local, so for compatibility we implement our
    861 // own. TODO(apassos) remove once we can deprecate xcode 7.
#ifndef __APPLE__
// Returns a pointer to this thread's "tape recording stopped" flag.
bool* ThreadTapeIsStopped() {
  thread_local bool thread_tape_is_stopped{false};
  return &thread_tape_is_stopped;
}
#else
// Fallback for toolchains without thread_local (xcode 7, per the note above):
// a process-wide map keyed by thread id. The map is intentionally never
// freed. NOTE(review): no explicit locking here — presumably all callers hold
// the GIL (as for the tape set above); confirm before calling without it.
static std::unordered_map<std::thread::id, bool>* tape_is_stopped = nullptr;
bool* ThreadTapeIsStopped() {
  if (tape_is_stopped == nullptr) {
    tape_is_stopped = new std::unordered_map<std::thread::id, bool>;
  }
  auto it = tape_is_stopped->find(std::this_thread::get_id());
  if (it != tape_is_stopped->end()) {
    return &(it->second);
  }
  // First query from this thread: default to "not stopped".
  return &(tape_is_stopped->emplace(std::this_thread::get_id(), false)
               .first->second);
}
#endif
    881 
    882 void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; }
    883 
    884 void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; }
    885 
    886 PyObject* TFE_Py_TapeSetNew(PyObject* persistent) {
    887   TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
    888   if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
    889   TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
    890   tape->tape = new GradientTape(persistent == Py_True);
    891   Py_INCREF(tape);
    892   GetTapeSet()->insert(reinterpret_cast<TFE_Py_Tape*>(tape));
    893   return reinterpret_cast<PyObject*>(tape);
    894 }
    895 
    896 PyObject* TFE_Py_TapeSetIsEmpty() {
    897   if (*ThreadTapeIsStopped() || GetTapeSet()->empty()) {
    898     Py_RETURN_TRUE;
    899   }
    900   Py_RETURN_FALSE;
    901 }
    902 
    903 void TFE_Py_TapeSetRemove(PyObject* tape) {
    904   auto* stack = GetTapeSet();
    905   stack->erase(reinterpret_cast<TFE_Py_Tape*>(tape));
    906   // We kept a reference to the tape in the set to ensure it wouldn't get
    907   // deleted under us; cleaning it up here.
    908   Py_DECREF(tape);
    909 }
    910 
    911 static std::vector<tensorflow::int64> MakeIntList(PyObject* list) {
    912   if (list == Py_None) {
    913     return {};
    914   }
    915   PyObject* seq = PySequence_Fast(list, "expected a sequence");
    916   if (seq == nullptr) {
    917     return {};
    918   }
    919   int len = PySequence_Size(list);
    920   std::vector<tensorflow::int64> tensor_ids;
    921   tensor_ids.reserve(len);
    922   for (int i = 0; i < len; ++i) {
    923     PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
    924 #if PY_MAJOR_VERSION >= 3
    925     if (PyLong_Check(item)) {
    926 #else
    927     if (PyLong_Check(item) || PyInt_Check(item)) {
    928 #endif
    929       tensorflow::int64 id = MakeInt(item);
    930       tensor_ids.push_back(id);
    931     } else {
    932       tensor_ids.push_back(-1);
    933     }
    934   }
    935   Py_DECREF(seq);
    936   return tensor_ids;
    937 }
    938 
    939 PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
    940   if (tensors == Py_None) {
    941     Py_RETURN_FALSE;
    942   }
    943   if (*ThreadTapeIsStopped()) {
    944     Py_RETURN_FALSE;
    945   }
    946   auto* tape_set_ptr = GetTapeSet();
    947   if (tape_set_ptr->empty()) {
    948     Py_RETURN_FALSE;
    949   }
    950   PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
    951   if (seq == nullptr) {
    952     return nullptr;
    953   }
    954   int len = PySequence_Fast_GET_SIZE(seq);
    955   // TODO(apassos) consider not building a list and changing the API to check
    956   // each tensor individually.
    957   std::vector<tensorflow::int64> tensor_ids;
    958   tensor_ids.reserve(len);
    959   for (int i = 0; i < len; ++i) {
    960     PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
    961     tensor_ids.push_back(FastTensorId(item));
    962   }
    963   Py_DECREF(seq);
    964   auto tape_set = *tape_set_ptr;
    965   for (TFE_Py_Tape* tape : tape_set) {
    966     if (tape->tape->ShouldRecord(tensor_ids)) {
    967       Py_RETURN_TRUE;
    968     }
    969   }
    970   Py_RETURN_FALSE;
    971 }
    972 
    973 void TFE_Py_TapeSetWatch(PyObject* tensor) {
    974   if (*ThreadTapeIsStopped()) {
    975     return;
    976   }
    977   tensorflow::int64 tensor_id = FastTensorId(tensor);
    978   if (PyErr_Occurred()) {
    979     return;
    980   }
    981   for (TFE_Py_Tape* tape : *GetTapeSet()) {
    982     tape->tape->Watch(tensor_id);
    983   }
    984 }
    985 
    986 static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
    987   if (EagerTensor_CheckExact(tensor)) {
    988     TFE_TensorHandle* t = EagerTensor_Handle(tensor);
    989     tensorflow::int64 id = EagerTensor_id(tensor);
    990     return tensorflow::eager::TapeTensor{id, t->t.dtype(), t->t.shape()};
    991   }
    992   tensorflow::int64 id = FastTensorId(tensor);
    993   if (PyErr_Occurred()) {
    994     return tensorflow::eager::TapeTensor{
    995         id, static_cast<tensorflow::DataType>(0), tensorflow::TensorShape({})};
    996   }
    997   PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype");
    998   PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum");
    999   Py_DECREF(dtype_object);
   1000   tensorflow::DataType dtype =
   1001       static_cast<tensorflow::DataType>(MakeInt(dtype_enum));
   1002   Py_DECREF(dtype_enum);
   1003   if (PyErr_Occurred() != nullptr) {
   1004     return tensorflow::eager::TapeTensor{id, dtype,
   1005                                          tensorflow::TensorShape({})};
   1006   }
   1007   static char _shape_tuple[] = "_shape_tuple";
   1008   PyObject* shape_tuple = PyObject_CallMethod(tensor, _shape_tuple, nullptr);
   1009   if (PyErr_Occurred() != nullptr) {
   1010     return tensorflow::eager::TapeTensor{id, dtype,
   1011                                          tensorflow::TensorShape({})};
   1012   }
   1013   auto l = MakeIntList(shape_tuple);
   1014   Py_DECREF(shape_tuple);
   1015   // Replace -1, which represents accidental Nones which can occur in graph mode
   1016   // and can cause errors in shape cosntruction with 0s.
   1017   for (auto& c : l) {
   1018     if (c < 0) {
   1019       c = 0;
   1020     }
   1021   }
   1022   tensorflow::TensorShape shape(l);
   1023   return tensorflow::eager::TapeTensor{id, dtype, shape};
   1024 }
   1025 
   1026 std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
   1027   PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
   1028   if (seq == nullptr) {
   1029     return {};
   1030   }
   1031   int len = PySequence_Fast_GET_SIZE(seq);
   1032   std::vector<tensorflow::int64> list;
   1033   list.reserve(len);
   1034   for (int i = 0; i < len; ++i) {
   1035     PyObject* tensor = PySequence_Fast_GET_ITEM(seq, i);
   1036     list.push_back(FastTensorId(tensor));
   1037     if (PyErr_Occurred()) {
   1038       Py_DECREF(seq);
   1039       return list;
   1040     }
   1041   }
   1042   Py_DECREF(seq);
   1043   return list;
   1044 }
   1045 
   1046 void TFE_Py_TapeSetWatchVariable(PyObject* variable) {
   1047   if (*ThreadTapeIsStopped()) {
   1048     return;
   1049   }
   1050   for (TFE_Py_Tape* tape : SafeTapeSet()) {
   1051     tape->tape->WatchVariable(variable);
   1052   }
   1053 }
   1054 
   1055 PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
   1056   const std::unordered_set<PyObject*>& watched_variables =
   1057       reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchedVariables();
   1058   PyObject* result = PySet_New(nullptr);
   1059   for (PyObject* variable : watched_variables) {
   1060     PySet_Add(result, variable);
   1061   }
   1062   return result;
   1063 }
   1064 
   1065 void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
   1066                                    PyObject* input_tensors,
   1067                                    PyObject* backward_function) {
   1068   if (GetTapeSet()->empty() || *ThreadTapeIsStopped()) {
   1069     return;
   1070   }
   1071   std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
   1072   if (PyErr_Occurred()) {
   1073     return;
   1074   }
   1075   std::vector<tensorflow::eager::TapeTensor> output_info;
   1076   PyObject* seq = PySequence_Fast(output_tensors,
   1077                                   "expected a sequence of integer tensor ids");
   1078   int len = PySequence_Size(output_tensors);
   1079   output_info.reserve(len);
   1080   for (int i = 0; i < len; ++i) {
   1081     output_info.push_back(
   1082         TapeTensorFromTensor(PySequence_Fast_GET_ITEM(seq, i)));
   1083     if (PyErr_Occurred() != nullptr) {
   1084       Py_DECREF(seq);
   1085       return;
   1086     }
   1087   }
   1088   Py_DECREF(seq);
   1089   string op_type_str;
   1090   if (PyBytes_Check(op_type)) {
   1091     op_type_str = PyBytes_AsString(op_type);
   1092   } else if (PyUnicode_Check(op_type)) {
   1093 #if PY_MAJOR_VERSION >= 3
   1094     op_type_str = PyUnicode_AsUTF8(op_type);
   1095 #else
   1096     PyObject* py_str = PyUnicode_AsUTF8String(op_type);
   1097     if (py_str == nullptr) return;
   1098     op_type_str = PyBytes_AS_STRING(py_str);
   1099     Py_DECREF(py_str);
   1100 #endif
   1101   } else {
   1102     PyErr_SetString(PyExc_RuntimeError, "op_type should be a string.");
   1103     return;
   1104   }
   1105 
   1106   for (TFE_Py_Tape* tape : SafeTapeSet()) {
   1107     Py_INCREF(backward_function);
   1108     tape->tape->RecordOperation(
   1109         op_type_str, output_info, input_ids, backward_function,
   1110         [backward_function]() { Py_DECREF(backward_function); });
   1111   }
   1112 }
   1113 
   1114 void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
   1115   for (TFE_Py_Tape* tape : SafeTapeSet()) {
   1116     tape->tape->DeleteTrace(tensor_id);
   1117   }
   1118 }
   1119 
   1120 class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyObject> {
   1121  public:
   1122   explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {}
   1123 
   1124   tensorflow::Status Initialize() {
   1125     num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
   1126     if (num_elements_ == nullptr) {
   1127       return tensorflow::errors::InvalidArgument("invalid vspace");
   1128     }
   1129     aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
   1130     if (aggregate_fn_ == nullptr) {
   1131       return tensorflow::errors::InvalidArgument("invalid vspace");
   1132     }
   1133     zeros_ = PyObject_GetAttrString(py_vspace_, "zeros");
   1134     if (zeros_ == nullptr) {
   1135       return tensorflow::errors::InvalidArgument("invalid vspace");
   1136     }
   1137     ones_ =
   1138         PyObject_GetAttrString(reinterpret_cast<PyObject*>(py_vspace_), "ones");
   1139     if (ones_ == nullptr) {
   1140       return tensorflow::errors::InvalidArgument("invalid vspace");
   1141     }
   1142     return tensorflow::Status::OK();
   1143   }
   1144 
   1145   ~PyVSpace() override {
   1146     Py_XDECREF(num_elements_);
   1147     Py_XDECREF(aggregate_fn_);
   1148     Py_XDECREF(zeros_);
   1149     Py_XDECREF(ones_);
   1150   }
   1151 
   1152   tensorflow::int64 NumElements(PyObject* tensor) const final {
   1153     PyObject* arglist =
   1154         Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
   1155     PyObject* result = PyEval_CallObject(num_elements_, arglist);
   1156     tensorflow::int64 r = MakeInt(result);
   1157     Py_DECREF(result);
   1158     Py_DECREF(arglist);
   1159     return r;
   1160   }
   1161 
   1162   PyObject* AggregateGradients(
   1163       tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
   1164     PyObject* list = PyList_New(gradient_tensors.size());
   1165     for (int i = 0; i < gradient_tensors.size(); ++i) {
   1166       // Note: stealing a reference to the gradient tensors.
   1167       CHECK(gradient_tensors[i] != nullptr);
   1168       CHECK(gradient_tensors[i] != Py_None);
   1169       PyList_SET_ITEM(list, i,
   1170                       reinterpret_cast<PyObject*>(gradient_tensors[i]));
   1171     }
   1172     PyObject* arglist = Py_BuildValue("(O)", list);
   1173     CHECK(arglist != nullptr);
   1174     PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
   1175     Py_DECREF(arglist);
   1176     Py_DECREF(list);
   1177     return result;
   1178   }
   1179 
   1180   PyObject* Zeros(tensorflow::TensorShape shape,
   1181                   tensorflow::DataType dtype) const final {
   1182     PyObject* py_shape = PyTuple_New(shape.dims());
   1183     for (int i = 0; i < shape.dims(); ++i) {
   1184       PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
   1185     }
   1186     PyObject* py_dtype = PyLong_FromLong(static_cast<int>(dtype));
   1187     PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
   1188     PyObject* result = PyEval_CallObject(zeros_, arg_list);
   1189     Py_DECREF(arg_list);
   1190     Py_DECREF(py_dtype);
   1191     Py_DECREF(py_shape);
   1192     return reinterpret_cast<PyObject*>(result);
   1193   }
   1194 
   1195   PyObject* Ones(tensorflow::TensorShape shape,
   1196                  tensorflow::DataType dtype) const final {
   1197     PyObject* py_shape = PyTuple_New(shape.dims());
   1198     for (int i = 0; i < shape.dims(); ++i) {
   1199       PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
   1200     }
   1201     PyObject* py_dtype = PyLong_FromLong(static_cast<int>(dtype));
   1202     PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
   1203     PyObject* result = PyEval_CallObject(ones_, arg_list);
   1204     Py_DECREF(arg_list);
   1205     Py_DECREF(py_dtype);
   1206     Py_DECREF(py_shape);
   1207     return result;
   1208   }
   1209 
   1210   tensorflow::Status CallBackwardFunction(
   1211       PyObject* backward_function,
   1212       tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
   1213       std::vector<PyObject*>* result) const final {
   1214     PyObject* grads = PyTuple_New(output_gradients.size());
   1215     for (int i = 0; i < output_gradients.size(); ++i) {
   1216       if (output_gradients[i] == nullptr) {
   1217         Py_INCREF(Py_None);
   1218         PyTuple_SET_ITEM(grads, i, Py_None);
   1219       } else {
   1220         PyTuple_SET_ITEM(grads, i,
   1221                          reinterpret_cast<PyObject*>(output_gradients[i]));
   1222       }
   1223     }
   1224     PyObject* py_result = PyEval_CallObject(
   1225         reinterpret_cast<PyObject*>(backward_function), grads);
   1226     Py_DECREF(grads);
   1227     if (py_result == nullptr) {
   1228       return tensorflow::errors::Internal("gradient function threw exceptions");
   1229     }
   1230     result->clear();
   1231     PyObject* seq =
   1232         PySequence_Fast(py_result, "expected a sequence of gradients");
   1233     if (seq == nullptr) {
   1234       return tensorflow::errors::InvalidArgument(
   1235           "gradient function did not return a list");
   1236     }
   1237     int len = PySequence_Fast_GET_SIZE(seq);
   1238     VLOG(1) << "Gradient length is " << len;
   1239     result->reserve(len);
   1240     for (int i = 0; i < len; ++i) {
   1241       PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
   1242       if (item == Py_None) {
   1243         result->push_back(nullptr);
   1244       } else {
   1245         Py_INCREF(item);
   1246         result->push_back(item);
   1247       }
   1248     }
   1249     Py_DECREF(seq);
   1250     Py_DECREF(py_result);
   1251     return tensorflow::Status::OK();
   1252   }
   1253 
   1254   void ReleaseBackwardFunction(PyObject* backward_function) const final {
   1255     Py_DECREF(backward_function);
   1256   }
   1257 
   1258   void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
   1259 
   1260  private:
   1261   PyObject* py_vspace_;
   1262 
   1263   PyObject* num_elements_;
   1264   PyObject* aggregate_fn_;
   1265   PyObject* zeros_;
   1266   PyObject* ones_;
   1267 };
   1268 
   1269 std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
   1270   PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
   1271   if (seq == nullptr) {
   1272     return {};
   1273   }
   1274   int len = PySequence_Fast_GET_SIZE(seq);
   1275   std::vector<PyObject*> list;
   1276   list.reserve(len);
   1277   for (int i = 0; i < len; ++i) {
   1278     list.push_back(PySequence_Fast_GET_ITEM(seq, i));
   1279   }
   1280   Py_DECREF(seq);
   1281   return list;
   1282 }
   1283 
// Computes gradients of `target` with respect to `sources` using `tape` and
// the Python-backed `vspace`. `output_gradients` may be Py_None or a sequence
// of initial gradients for `target`. Returns a new Python list of gradients
// (entries may be None), or nullptr on failure with either a Python exception
// or `status` populated.
PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
                              PyObject* target, PyObject* sources,
                              PyObject* output_gradients, TF_Status* status) {
  PyVSpace c_vspace(vspace);
  if (!c_vspace.Initialize().ok()) {
    return nullptr;
  }

  std::vector<tensorflow::int64> target_vec = MakeTensorIDList(target);
  if (PyErr_Occurred()) {
    return nullptr;
  }
  std::vector<tensorflow::int64> sources_vec = MakeTensorIDList(sources);
  if (PyErr_Occurred()) {
    return nullptr;
  }
  std::vector<PyObject*> outgrad_vec;
  if (output_gradients != Py_None) {
    outgrad_vec = MakeTensorList(output_gradients);
    if (PyErr_Occurred()) {
      return nullptr;
    }
    for (PyObject* tensor : outgrad_vec) {
      // Calling the backward function will eat a reference to the tensors in
      // outgrad_vec, so we need to increase their reference count.
      Py_INCREF(tensor);
    }
  }
  TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape);
  std::vector<PyObject*> result;
  status->status = tape_obj->tape->ComputeGradient(
      c_vspace, target_vec, sources_vec, outgrad_vec, &result);
  if (!status->status.ok()) {
    if (PyErr_Occurred()) {
      // Do not propagate the erroneous status as that would swallow the
      // exception which caused the problem.
      status->status = tensorflow::Status::OK();
    }
    return nullptr;
  }
  if (!result.empty()) {
    // ComputeGradient returns nullptr for missing gradients; surface those to
    // Python as None. The list steals the references in `result`.
    PyObject* py_result = PyList_New(result.size());
    for (int i = 0; i < result.size(); ++i) {
      if (result[i] == nullptr) {
        Py_INCREF(Py_None);
        result[i] = Py_None;
      }
      PyList_SET_ITEM(py_result, i, reinterpret_cast<PyObject*>(result[i]));
    }
    return py_result;
  }
  return PyList_New(0);
}
   1337 
   1338 namespace {
   1339 static const int kFastPathExecuteInputStartIndex = 6;
   1340 
// Returns a new reference to a Python string (unicode on Python 3, bytes on
// Python 2) built from the NUL-terminated C string `s`.
PyObject* GetPythonObjectFromString(const char* s) {
#if PY_MAJOR_VERSION >= 3
  return PyUnicode_FromString(s);
#else
  return PyBytes_FromString(s);
#endif
}
   1348 
   1349 bool CheckEagerTensors(PyObject* seq, int start_index,
   1350                        const tensorflow::OpDef& op_def) {
   1351   for (int i = 0; i < op_def.input_arg_size(); i++) {
   1352     PyObject* item = PyTuple_GET_ITEM(seq, i + start_index);
   1353     if (!op_def.input_arg(i).number_attr().empty() ||
   1354         !op_def.input_arg(i).type_list_attr().empty()) {
   1355       // This item should be a list input.
   1356       if (!PyList_Check(item)) return false;
   1357       for (Py_ssize_t j = 0; j < PyList_Size(item); j++) {
   1358         if (!EagerTensor_CheckExact(PyList_GET_ITEM(item, j))) return false;
   1359       }
   1360     } else if (!EagerTensor_CheckExact(item)) {
   1361       return false;
   1362     }
   1363   }
   1364 
   1365   return true;
   1366 }
   1367 
   1368 // Adds input and type attr to the op, and to the list of flattened
   1369 // inputs/attrs.
   1370 bool AddInputToOp(PyObject* input, const tensorflow::OpDef::ArgDef* input_arg,
   1371                   std::vector<PyObject*>* flattened_attrs,
   1372                   std::vector<PyObject*>* flattened_inputs, TFE_Op* op,
   1373                   TF_Status* status) {
   1374   TFE_TensorHandle* input_handle = EagerTensor_Handle(input);
   1375   if (input_arg != nullptr && !input_arg->type_attr().empty()) {
   1376     auto dtype = TFE_TensorHandleDataType(input_handle);
   1377     TFE_OpSetAttrType(op, input_arg->type_attr().data(), dtype);
   1378     if (flattened_attrs != nullptr) {
   1379       flattened_attrs->push_back(
   1380           GetPythonObjectFromString(input_arg->type_attr().data()));
   1381       flattened_attrs->push_back(PyLong_FromLong(dtype));
   1382     }
   1383   }
   1384 
   1385   if (flattened_inputs != nullptr) {
   1386     flattened_inputs->push_back(input);
   1387   }
   1388   TFE_OpAddInput(op, input_handle, status);
   1389   if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
   1390     return false;
   1391   }
   1392   return true;
   1393 }
   1394 
   1395 const tensorflow::OpDef* GetOpDef(PyObject* py_op_name) {
   1396   const char* op_name = TFE_GetPythonString(py_op_name);
   1397   if (op_name == nullptr) {
   1398     PyErr_SetString(PyExc_TypeError,
   1399                     Printf("expected a string for op_name, got %s instead",
   1400                            py_op_name->ob_type->tp_name)
   1401                         .c_str());
   1402     return nullptr;
   1403   }
   1404 
   1405   const tensorflow::OpRegistrationData* op_reg_data = nullptr;
   1406   const tensorflow::Status lookup_status =
   1407       tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data);
   1408   if (MaybeRaiseExceptionFromStatus(lookup_status, nullptr)) {
   1409     return nullptr;
   1410   }
   1411   return &op_reg_data->op_def;
   1412 }
   1413 
   1414 const char* GetDeviceName(PyObject* py_device_name) {
   1415   if (py_device_name != Py_None) {
   1416     return TFE_GetPythonString(py_device_name);
   1417   }
   1418   return nullptr;
   1419 }
   1420 
   1421 bool RaiseIfNotPyList(PyObject* list, const string& attr_name) {
   1422   if (!PyList_Check(list)) {
   1423     PyErr_SetString(PyExc_TypeError,
   1424                     Printf("expected a list for attr %s, got %s instead",
   1425                            attr_name.data(), list->ob_type->tp_name)
   1426                         .data());
   1427 
   1428     return false;
   1429   }
   1430   return true;
   1431 }
   1432 
// Invokes the gradient-recording callback and/or the post-execution callbacks
// for an op executed on the fast path. `args` is the original argument tuple
// passed to TFE_Py_FastPathExecute_C; attrs/inputs inferred from the inputs
// arrive pre-flattened in `flattened_attrs`/`flattened_inputs`. Each callback
// is called with (op_name, inputs, attrs, flattened_result, name). Returns
// false with a Python error set if any callback is not callable or raises.
// Note: steals the references held in `flattened_attrs`.
bool RunCallbacks(bool run_gradient_callback, bool run_post_exec_callbacks,
                  const tensorflow::OpDef* op_def, PyObject* args,
                  const std::vector<PyObject*>& flattened_inputs,
                  const std::vector<PyObject*>& flattened_attrs,
                  PyObject* flattened_result, PyObject* op_name, PyObject* name,
                  PyObject* record_gradient_callback, PyObject* callbacks) {
  // Build a tuple of the inputs, taking a new reference to each.
  PyObject* inputs = PyTuple_New(flattened_inputs.size());
  for (int i = 0; i < flattened_inputs.size(); i++) {
    PyObject* input = flattened_inputs[i];
    Py_INCREF(input);
    PyTuple_SET_ITEM(inputs, i, input);
  }

  // The attrs tuple holds the attrs passed explicitly in `args` followed by
  // the attrs inferred from the inputs.
  int num_non_inferred_attrs = PyTuple_GET_SIZE(args) -
                               op_def->input_arg_size() -
                               kFastPathExecuteInputStartIndex;
  int num_attrs = flattened_attrs.size() + num_non_inferred_attrs;
  PyObject* attrs = PyTuple_New(num_attrs);

  for (int i = 0; i < num_non_inferred_attrs; i++) {
    auto* attr = PyTuple_GET_ITEM(
        args, kFastPathExecuteInputStartIndex + op_def->input_arg_size() + i);
    Py_INCREF(attr);
    PyTuple_SET_ITEM(attrs, i, attr);
  }
  for (int i = num_non_inferred_attrs; i < num_attrs; i++) {
    // Not INCREFing anything in flattened_attrs as each of those is a new
    // reference, so allow the attrs tuple to steal the reference.
    PyTuple_SET_ITEM(attrs, i, flattened_attrs.at(i - num_non_inferred_attrs));
  }

  PyObject* callback_args =
      Py_BuildValue("OOOOO", op_name, inputs, attrs, flattened_result, name);

  // Release the temporaries above on every exit path, including the early
  // `return false` branches below.
  auto cleaner = tensorflow::gtl::MakeCleanup([inputs, attrs, callback_args] {
    Py_DECREF(inputs);
    Py_DECREF(attrs);
    Py_DECREF(callback_args);
  });

  if (run_gradient_callback) {
    if (!PyCallable_Check(record_gradient_callback)) {
      PyErr_SetString(PyExc_TypeError,
                      Printf("expected a function for "
                             "record_gradient_callback, got %s instead",
                             record_gradient_callback->ob_type->tp_name)
                          .c_str());
      return false;
    }

    PyObject* callback_result =
        PyObject_CallObject(record_gradient_callback, callback_args);
    if (!callback_result) {
      // Propagate the exception raised by the callback.
      return false;
    }
    Py_DECREF(callback_result);
  }

  if (run_post_exec_callbacks) {
    for (Py_ssize_t i = 0; i < PyList_Size(callbacks); i++) {
      PyObject* callback_fn = PyList_GET_ITEM(callbacks, i);
      if (!PyCallable_Check(callback_fn)) {
        PyErr_SetString(
            PyExc_TypeError,
            Printf("expected a function for "
                   "post execution callback in index %ld, got %s instead",
                   i, callback_fn->ob_type->tp_name)
                .c_str());
        return false;
      }
      PyObject* callback_result =
          PyObject_CallObject(callback_fn, callback_args);
      if (!callback_result) {
        return false;
      }
      Py_DECREF(callback_result);
    }
  }

  return true;
}
   1514 
   1515 }  // namespace
   1516 
   1517 PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
   1518   Py_ssize_t args_size = PyTuple_GET_SIZE(args);
   1519   if (args_size < kFastPathExecuteInputStartIndex) {
   1520     PyErr_SetString(
   1521         PyExc_ValueError,
   1522         Printf("There must be at least %d items in the input tuple.",
   1523                kFastPathExecuteInputStartIndex)
   1524             .c_str());
   1525     return nullptr;
   1526   }
   1527 
   1528   TFE_Context* ctx = reinterpret_cast<TFE_Context*>(
   1529       PyCapsule_GetPointer(PyTuple_GET_ITEM(args, 0), nullptr));
   1530   const char* device_name = GetDeviceName(PyTuple_GET_ITEM(args, 1));
   1531   PyObject* op_name = PyTuple_GET_ITEM(args, 2);
   1532   const tensorflow::OpDef* op_def = GetOpDef(op_name);
   1533   if (op_def == nullptr) return nullptr;
   1534   PyObject* record_gradient_callback = PyTuple_GET_ITEM(args, 3);
   1535   PyObject* name = PyTuple_GET_ITEM(args, 4);
   1536   PyObject* callbacks = PyTuple_GET_ITEM(args, 5);
   1537 
   1538   if (args_size < kFastPathExecuteInputStartIndex + op_def->input_arg_size()) {
   1539     PyErr_SetString(
   1540         PyExc_ValueError,
   1541         Printf("Tuple size smaller than intended. Expected to be at least %d, "
   1542                "was %ld",
   1543                kFastPathExecuteInputStartIndex + op_def->input_arg_size(),
   1544                args_size)
   1545             .c_str());
   1546     return nullptr;
   1547   }
   1548 
   1549   if (!CheckEagerTensors(args, kFastPathExecuteInputStartIndex, *op_def)) {
   1550     RaiseFallbackException(
   1551         "This function does not handle the case of the path where "
   1552         "all inputs are not already EagerTensors.");
   1553     return nullptr;
   1554   }
   1555 
   1556   TF_Status* status = TF_NewStatus();
   1557   TFE_Op* op = TFE_NewOp(ctx, op_def->name().c_str(), status);
   1558   auto cleaner = tensorflow::gtl::MakeCleanup([status, op] {
   1559     TF_DeleteStatus(status);
   1560     TFE_DeleteOp(op);
   1561   });
   1562   if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
   1563     return nullptr;
   1564   }
   1565 
   1566   // Mapping of attr name to size - used to calculate the number of values
   1567   // to be expected by the TFE_Execute run.
   1568   tensorflow::gtl::FlatMap<string, tensorflow::int64> attr_list_sizes;
   1569 
   1570   // Set non-inferred attrs, including setting defaults if the attr is passed in
   1571   // as None.
   1572   for (int i = kFastPathExecuteInputStartIndex + op_def->input_arg_size();
   1573        i < args_size; i += 2) {
   1574     PyObject* py_attr_name = PyTuple_GET_ITEM(args, i);
   1575     const tensorflow::StringPiece attr_name(TFE_GetPythonString(py_attr_name));
   1576     PyObject* py_attr_value = PyTuple_GET_ITEM(args, i + 1);
   1577 
   1578     // Not creating an index since most of the time there are not more than a
   1579     // few attrs.
   1580     // TODO(nareshmodi): Maybe include the index as part of the
   1581     // OpRegistrationData.
   1582     for (const auto& attr : op_def->attr()) {
   1583       if (attr_name == attr.name()) {
   1584         SetOpAttrWithDefaults(ctx, op, attr, attr_name.data(), py_attr_value,
   1585                               &attr_list_sizes, status);
   1586 
   1587         if (TF_GetCode(status) != TF_OK) {
   1588           RaiseFallbackException(TF_Message(status));
   1589           return nullptr;
   1590         }
   1591 
   1592         break;
   1593       }
   1594     }
   1595   }
   1596 
   1597   TFE_OpSetDevice(op, device_name, status);
   1598   if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
   1599     return nullptr;
   1600   }
   1601 
   1602   // TODO(nareshmodi): Add a benchmark for the fast-path with gradient callbacks
   1603   // (similar to benchmark_tf_gradient_function_*). Also consider using an
   1604   // InlinedVector for flattened_attrs and flattened_inputs if the benchmarks
   1605   // point out problems with heap allocs.
   1606   bool run_gradient_callback = !*ThreadTapeIsStopped() &&
   1607                                !GetTapeSet()->empty() &&
   1608                                record_gradient_callback != Py_None;
   1609   bool run_post_exec_callbacks =
   1610       callbacks != Py_None && PyList_Size(callbacks) > 0;
   1611   bool run_callbacks = run_gradient_callback || run_post_exec_callbacks;
   1612   // Flat attrs and inputs as required by the record_gradient call. The attrs
   1613   // here only contain inferred attrs (non-inferred attrs are added directly
   1614   // from the input args).
   1615   // All items in flattened_attrs contain new references.
   1616   // All items in flattened_inputs contain borrowed references.
   1617   // TODO(nareshmodi): figure out why PyList_New/PyList_Append don't work
   1618   // directly.
   1619   std::unique_ptr<std::vector<PyObject*>> flattened_attrs = nullptr;
   1620   std::unique_ptr<std::vector<PyObject*>> flattened_inputs = nullptr;
   1621 
   1622   if (run_callbacks) {
   1623     flattened_attrs.reset(new std::vector<PyObject*>);
   1624     flattened_inputs.reset(new std::vector<PyObject*>);
   1625   }
   1626 
   1627   // Add inferred attrs and inputs.
   1628   // The following code might set duplicate type attrs. This will result in
   1629   // the CacheKey for the generated AttrBuilder possibly differing from
   1630   // those where the type attrs are correctly set. Inconsistent CacheKeys
   1631   // for ops means that there might be unnecessarily duplicated kernels.
   1632   // TODO(nareshmodi): Fix this.
   1633   for (int i = 0; i < op_def->input_arg_size(); i++) {
   1634     const auto& input_arg = op_def->input_arg(i);
   1635 
   1636     PyObject* input =
   1637         PyTuple_GET_ITEM(args, kFastPathExecuteInputStartIndex + i);
   1638     if (!input_arg.number_attr().empty()) {
   1639       // The item is a homogeneous list.
   1640       if (!RaiseIfNotPyList(input, input_arg.number_attr())) return nullptr;
   1641       Py_ssize_t len = PyList_Size(input);
   1642 
   1643       TFE_OpSetAttrInt(op, input_arg.number_attr().data(), len);
   1644       if (run_callbacks) {
   1645         flattened_attrs->push_back(
   1646             GetPythonObjectFromString(input_arg.number_attr().data()));
   1647         flattened_attrs->push_back(PyLong_FromLong(len));
   1648       }
   1649       attr_list_sizes[input_arg.number_attr()] = len;
   1650 
   1651       if (len > 0) {
   1652         // First item adds the type attr.
   1653         if (!AddInputToOp(PyList_GET_ITEM(input, 0), &input_arg,
   1654                           flattened_attrs.get(), flattened_inputs.get(), op,
   1655                           status)) {
   1656           return nullptr;
   1657         }
   1658 
   1659         for (Py_ssize_t j = 1; j < len; j++) {
   1660           // Since the list is homogeneous, we don't need to re-add the attr.
   1661           if (!AddInputToOp(PyList_GET_ITEM(input, j), nullptr /* input_arg */,
   1662                             nullptr /* flattened_attrs */,
   1663                             flattened_inputs.get(), op, status)) {
   1664             return nullptr;
   1665           }
   1666         }
   1667       }
   1668     } else if (!input_arg.type_list_attr().empty()) {
   1669       // The item is a heterogeneous list.
   1670       if (!RaiseIfNotPyList(input, input_arg.type_list_attr())) return nullptr;
   1671       const string& attr_name = input_arg.type_list_attr();
   1672       Py_ssize_t len = PyList_Size(input);
   1673       tensorflow::gtl::InlinedVector<TF_DataType, 4> attr_value(len);
   1674       PyObject* py_attr_value = nullptr;
   1675       if (run_callbacks) {
   1676         py_attr_value = PyTuple_New(len);
   1677       }
   1678       for (Py_ssize_t j = 0; j < len; j++) {
   1679         PyObject* py_input = PyList_GET_ITEM(input, j);
   1680         TFE_TensorHandle* input_handle = EagerTensor_Handle(py_input);
   1681         attr_value[j] = TFE_TensorHandleDataType(input_handle);
   1682 
   1683         TFE_OpAddInput(op, input_handle, status);
   1684         if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
   1685           return nullptr;
   1686         }
   1687 
   1688         if (run_callbacks) {
   1689           flattened_inputs->push_back(py_input);
   1690 
   1691           PyTuple_SET_ITEM(py_attr_value, j, PyLong_FromLong(attr_value[j]));
   1692         }
   1693       }
   1694       if (run_callbacks) {
   1695         flattened_attrs->push_back(GetPythonObjectFromString(attr_name.data()));
   1696         flattened_attrs->push_back(py_attr_value);
   1697       }
   1698       TFE_OpSetAttrTypeList(op, attr_name.data(), attr_value.data(),
   1699                             attr_value.size());
   1700       attr_list_sizes[attr_name] = len;
   1701     } else {
   1702       // The item is a single item.
   1703       if (!AddInputToOp(input, &input_arg, flattened_attrs.get(),
   1704                         flattened_inputs.get(), op, status)) {
   1705         return nullptr;
   1706       }
   1707     }
   1708   }
   1709 
   1710   int num_retvals = 0;
   1711   for (int i = 0; i < op_def->output_arg_size(); i++) {
   1712     const auto& output_arg = op_def->output_arg(i);
   1713     if (!output_arg.number_attr().empty()) {
   1714       num_retvals += attr_list_sizes[output_arg.number_attr()];
   1715     } else if (!output_arg.type_list_attr().empty()) {
   1716       num_retvals += attr_list_sizes[output_arg.type_list_attr()];
   1717     } else {
   1718       num_retvals++;
   1719     }
   1720   }
   1721 
   1722   tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals);
   1723 
   1724   Py_BEGIN_ALLOW_THREADS;
   1725   TFE_Execute(op, retvals.data(), &num_retvals, status);
   1726   Py_END_ALLOW_THREADS;
   1727   if (TF_GetCode(status) != TF_OK) {
   1728     // Augment the status with the op_name for easier debugging similar to
   1729     // TFE_Py_Execute.
   1730     TF_SetStatus(status, TF_GetCode(status),
   1731                  tensorflow::strings::StrCat(TF_Message(status), " [Op:",
   1732                                              TFE_GetPythonString(op_name), "]")
   1733                      .c_str());
   1734 
   1735     MaybeRaiseExceptionFromTFStatus(status, nullptr);
   1736     return nullptr;
   1737   }
   1738 
   1739   PyObject* flat_result = PyList_New(num_retvals);
   1740   for (int i = 0; i < num_retvals; ++i) {
   1741     PyList_SET_ITEM(flat_result, i, EagerTensorFromHandle(retvals[i]));
   1742   }
   1743 
   1744   if (run_callbacks &&
   1745       !RunCallbacks(run_gradient_callback, run_post_exec_callbacks, op_def,
   1746                     args, *flattened_inputs, *flattened_attrs, flat_result,
   1747                     op_name, name, record_gradient_callback, callbacks)) {
   1748     return nullptr;
   1749   }
   1750 
   1751   // Unflatten results.
   1752   if (op_def->output_arg_size() == 0) {
   1753     Py_RETURN_NONE;
   1754   }
   1755 
   1756   if (op_def->output_arg_size() == 1) {
   1757     if (!op_def->output_arg(0).number_attr().empty() ||
   1758         !op_def->output_arg(0).type_list_attr().empty()) {
   1759       return flat_result;
   1760     } else {
   1761       auto* result = PyList_GET_ITEM(flat_result, 0);
   1762       Py_INCREF(result);
   1763       Py_DECREF(flat_result);
   1764       return result;
   1765     }
   1766   }
   1767 
   1768   // Correctly output the results that are made into a namedtuple.
   1769   PyObject* result = PyList_New(op_def->output_arg_size());
   1770   int flat_result_index = 0;
   1771   for (int i = 0; i < op_def->output_arg_size(); i++) {
   1772     if (!op_def->output_arg(i).number_attr().empty()) {
   1773       int list_length = attr_list_sizes[op_def->output_arg(i).number_attr()];
   1774       PyObject* inner_list = PyList_New(list_length);
   1775       for (int j = 0; j < list_length; j++) {
   1776         PyObject* obj = PyList_GET_ITEM(flat_result, flat_result_index++);
   1777         Py_INCREF(obj);
   1778         PyList_SET_ITEM(inner_list, j, obj);
   1779       }
   1780       PyList_SET_ITEM(result, i, inner_list);
   1781     } else if (!op_def->output_arg(i).type_list_attr().empty()) {
   1782       int list_length = attr_list_sizes[op_def->output_arg(i).type_list_attr()];
   1783       PyObject* inner_list = PyList_New(list_length);
   1784       for (int j = 0; j < list_length; j++) {
   1785         PyObject* obj = PyList_GET_ITEM(flat_result, flat_result_index++);
   1786         Py_INCREF(obj);
   1787         PyList_SET_ITEM(inner_list, j, obj);
   1788       }
   1789       PyList_SET_ITEM(result, i, inner_list);
   1790     } else {
   1791       PyObject* obj = PyList_GET_ITEM(flat_result, flat_result_index++);
   1792       Py_INCREF(obj);
   1793       PyList_SET_ITEM(result, i, obj);
   1794     }
   1795   }
   1796   Py_DECREF(flat_result);
   1797   return result;
   1798 }
   1799