/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Python/C glue for TensorFlow eager execution (TFE_*): attribute parsing,
// eager op execution, exception registration, and gradient-tape bookkeeping.

#include <thread>

#include "tensorflow/python/eager/pywrap_tfe.h"

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tape.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/compactptrset.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/eager/pywrap_tensor.h"

using tensorflow::string;
using tensorflow::strings::Printf;

namespace {

// Expands to a parser
//   bool fn_name(const string& key, PyObject* py_value, TF_Status* status,
//                type* value)
// that validates py_value with check_fn and, on success, converts it with
// parse_fn into *value and returns true. On a type mismatch it fills `status`
// with INVALID_ARGUMENT naming the attr and the offending Python type, and
// returns false.
#define PARSE_VALUE(fn_name, type, check_fn, parse_fn)                       \
  bool fn_name(const string& key, PyObject* py_value, TF_Status* status,     \
               type* value) {                                                \
    if (check_fn(py_value)) {                                                \
      *value = static_cast<type>(parse_fn(py_value));                        \
      return true;                                                           \
    } else {                                                                 \
      TF_SetStatus(status, TF_INVALID_ARGUMENT,                              \
                   tensorflow::strings::StrCat(                              \
                       "Expecting " #type " value for attr ", key, ", got ", \
                       py_value->ob_type->tp_name)                           \
                       .c_str());                                            \
      return false;                                                          \
    }                                                                        \
  }

#if PY_MAJOR_VERSION >= 3
PARSE_VALUE(ParseIntValue, int, PyLong_Check, PyLong_AsLong)
PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLong)
#else
PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong)
PARSE_VALUE(ParseInt64Value, int64_t, PyInt_Check, PyInt_AsLong)
PARSE_VALUE(ParseInt64LongValue, int64_t, PyLong_Check, PyLong_AsLong)
#endif
PARSE_VALUE(ParseFloatValue, float, PyFloat_Check, PyFloat_AsDouble)
#undef PARSE_VALUE

// Returns the number of dimensions of a TensorShape-like sequence, or -1
// (with the pending Python error cleared) when the rank is unknown.
Py_ssize_t TensorShapeNumDims(PyObject* value) {
  const auto size = PySequence_Size(value);
  if (size == -1) {
    // TensorShape.__len__ raises an error in the scenario where the shape is an
    // unknown, which needs to be cleared.
    // TODO(nareshmodi): ensure that this is actually a TensorShape.
    PyErr_Clear();
  }
  return size;
}

// Extracts a C string from a Python bytes (or, on Python 3, unicode) object.
// On success *value points into py_value's internal buffer and is only valid
// while py_value stays alive. Sets INVALID_ARGUMENT and returns false for any
// other type.
bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status,
                      const char** value) {
  if (PyBytes_Check(py_value)) {
    *value = PyBytes_AsString(py_value);
    return true;
  }
#if PY_MAJOR_VERSION >= 3
  if (PyUnicode_Check(py_value)) {
    *value = PyUnicode_AsUTF8(py_value);
    return true;
  }
#endif
  TF_SetStatus(
      status, TF_INVALID_ARGUMENT,
      tensorflow::strings::StrCat("Expecting a string value for attr ", key,
                                  ", got ", py_value->ob_type->tp_name)
          .c_str());
  return false;
}

// Coerces any Python object to a bool via its truth value; never reports
// failure.
// NOTE(review): PyObject_IsTrue returns -1 on error, which would be stored
// here as a nonzero ("true") value with a Python error left pending —
// confirm callers only pass objects whose truthiness cannot raise.
bool ParseBoolValue(const string& key, PyObject* py_value, TF_Status* status,
                    unsigned char* value) {
  *value = PyObject_IsTrue(py_value);
  return true;
}

// Returns true iff py_value is a Python integer (any int on Python 3; only
// PyInt, not PyLong, on Python 2).
bool IsInteger(PyObject* py_value) {
#if PY_MAJOR_VERSION >= 3
  return PyLong_Check(py_value);
#else
  return PyInt_Check(py_value);
#endif
}

// The passed in py_value is expected to be an object of the python type
// dtypes.DType or an int.
// Parses a TF_DataType enum value out of py_value, which may be either a
// plain Python int or a dtypes.DType-like object (read via its _type_enum
// attribute). Returns false on failure; for a missing attribute the Python
// AttributeError is left pending, otherwise `status` is set.
bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status,
                    int* value) {
  if (IsInteger(py_value)) {
    return ParseIntValue(key, py_value, status, value);
  }

  PyObject* py_type_enum = PyObject_GetAttrString(py_value, "_type_enum");
  if (py_type_enum == nullptr) {
    return false;
  }

  if (!ParseIntValue(key, py_type_enum, status, value)) {
    Py_DECREF(py_type_enum);
    return false;
  }

  Py_DECREF(py_type_enum);
  return true;
}

// Sets the list-valued attr `key` on `op` from the Python sequence `py_list`,
// dispatching on the attr's element type. When attr_list_sizes is non-null
// the sequence length is recorded there under `key`. Returns false (with
// `status` set) on any parse failure.
// NOTE(review): PySequence_ITEM returns a new reference that is never
// DECREF'd in this function; for the string case the parsed char* aliases
// that object's buffer, so a naive DECREF would dangle — a proper fix needs
// scoped references that outlive the TFE_OpSetAttr* call. Confirm whether the
// leak is intentional.
bool SetOpAttrList(
    TFE_Op* op, const char* key, PyObject* py_list, TF_AttrType type,
    tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
    TF_Status* status) {
  if (!PySequence_Check(py_list)) {
    TF_SetStatus(
        status, TF_INVALID_ARGUMENT,
        tensorflow::strings::StrCat("Expecting sequence value for attr ", key,
                                    ", got ", py_list->ob_type->tp_name)
            .c_str());
    return false;
  }
  const int num_values = PySequence_Size(py_list);
  if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = num_values;

// Parses every element of py_list into a freshly allocated `values` array
// using parse_fn; bails out of the enclosing function on the first failure.
#define PARSE_LIST(c_type, parse_fn)                                \
  std::unique_ptr<c_type[]> values(new c_type[num_values]);         \
  for (int i = 0; i < num_values; ++i) {                            \
    auto py_value = PySequence_ITEM(py_list, i);                    \
    if (!parse_fn(key, py_value, status, &values[i])) return false; \
  }

  if (type == TF_ATTR_STRING) {
    PARSE_LIST(const char*, ParseStringValue);
    TFE_OpSetAttrStringList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_INT) {
    PARSE_LIST(int64_t, ParseInt64Value);
    TFE_OpSetAttrIntList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_FLOAT) {
    PARSE_LIST(float, ParseFloatValue);
    TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_BOOL) {
    PARSE_LIST(unsigned char, ParseBoolValue);
    TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
  } else if (type == TF_ATTR_TYPE) {
    PARSE_LIST(int, ParseTypeValue);
    TFE_OpSetAttrTypeList(op, key,
                          reinterpret_cast<const TF_DataType*>(values.get()),
                          num_values);
  } else if (type == TF_ATTR_SHAPE) {
    // Make one pass through the input counting the total number of
    // dims across all the input lists.
    int total_dims = 0;
    for (int i = 0; i < num_values; ++i) {
      auto py_value = PySequence_ITEM(py_list, i);
      if (py_value != Py_None) {
        if (!PySequence_Check(py_value)) {
          TF_SetStatus(
              status, TF_INVALID_ARGUMENT,
              tensorflow::strings::StrCat(
                  "Expecting None or sequence value for element", i,
                  " of attr ", key, ", got ", py_value->ob_type->tp_name)
                  .c_str());
          return false;
        }
        const auto size = TensorShapeNumDims(py_value);
        if (size >= 0) {
          total_dims += size;
        }
      }
    }
    // Allocate a buffer that can fit all of the dims together.
    std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
    // Copy the input dims into the buffer and set dims to point to
    // the start of each list's dims.
    std::unique_ptr<const int64_t* []> dims(new const int64_t*[num_values]);
    std::unique_ptr<int[]> num_dims(new int[num_values]);
    int64_t* offset = buffer.get();
    for (int i = 0; i < num_values; ++i) {
      auto py_value = PySequence_ITEM(py_list, i);
      if (py_value == Py_None) {
        // None means "shape of unknown rank" for this element.
        dims[i] = nullptr;
        num_dims[i] = -1;
      } else {
        const auto size = TensorShapeNumDims(py_value);
        if (size == -1) {
          dims[i] = nullptr;
          num_dims[i] = -1;
          continue;
        }
        dims[i] = offset;
        num_dims[i] = size;
        for (int j = 0; j < size; ++j) {
          auto inner_py_value = PySequence_ITEM(py_value, j);
          if (inner_py_value == Py_None) {
            // None inside a shape means "unknown dimension".
            *offset = -1;
          } else if (!ParseInt64Value(key, inner_py_value, status, offset)) {
            return false;
          }
          ++offset;
        }
      }
    }
    TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
                           status);
    if (TF_GetCode(status) != TF_OK) return false;
  } else {
    TF_SetStatus(status, TF_UNIMPLEMENTED,
                 tensorflow::strings::StrCat("Attr ", key,
                                             " has unhandled list type ", type)
                     .c_str());
    return false;
  }
#undef PARSE_LIST
  return true;
}

// This is only declared here since GetFunc makes a recursive call to
// SetOpAttrScalarDefault.
242 void SetOpAttrScalarDefault( 243 TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value, 244 const char* attr_name, 245 tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes, 246 TF_Status* status); 247 248 TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, 249 TF_Status* status) { 250 TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status); 251 for (const auto& attr : func.attr()) { 252 if (TF_GetCode(status) != TF_OK) return nullptr; 253 SetOpAttrScalarDefault(ctx, func_op, attr.second, attr.first.data(), 254 nullptr, status); 255 if (TF_GetCode(status) != TF_OK) return nullptr; 256 } 257 return func_op; 258 } 259 260 void SetOpAttrListDefault( 261 TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr, 262 const char* key, TF_AttrType type, 263 tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes, 264 TF_Status* status) { 265 if (type == TF_ATTR_STRING) { 266 int num_values = attr.default_value().list().s_size(); 267 std::unique_ptr<const char* []> values(new const char*[num_values]); 268 (*attr_list_sizes)[key] = num_values; 269 for (int i = 0; i < num_values; i++) { 270 values[i] = attr.default_value().list().s(i).data(); 271 } 272 TFE_OpSetAttrStringList(op, key, values.get(), num_values); 273 } else if (type == TF_ATTR_INT) { 274 int num_values = attr.default_value().list().i_size(); 275 std::unique_ptr<int64_t[]> values(new int64_t[num_values]); 276 (*attr_list_sizes)[key] = num_values; 277 for (int i = 0; i < num_values; i++) { 278 values[i] = attr.default_value().list().i(i); 279 } 280 TFE_OpSetAttrIntList(op, key, values.get(), num_values); 281 } else if (type == TF_ATTR_FLOAT) { 282 int num_values = attr.default_value().list().f_size(); 283 std::unique_ptr<float[]> values(new float[num_values]); 284 (*attr_list_sizes)[key] = num_values; 285 for (int i = 0; i < num_values; i++) { 286 values[i] = attr.default_value().list().f(i); 287 } 288 
TFE_OpSetAttrFloatList(op, key, values.get(), num_values); 289 } else if (type == TF_ATTR_BOOL) { 290 int num_values = attr.default_value().list().b_size(); 291 std::unique_ptr<unsigned char[]> values(new unsigned char[num_values]); 292 (*attr_list_sizes)[key] = num_values; 293 for (int i = 0; i < num_values; i++) { 294 values[i] = attr.default_value().list().b(i); 295 } 296 TFE_OpSetAttrBoolList(op, key, values.get(), num_values); 297 } else if (type == TF_ATTR_TYPE) { 298 int num_values = attr.default_value().list().type_size(); 299 std::unique_ptr<int[]> values(new int[num_values]); 300 (*attr_list_sizes)[key] = num_values; 301 for (int i = 0; i < num_values; i++) { 302 values[i] = attr.default_value().list().type(i); 303 } 304 TFE_OpSetAttrTypeList(op, key, 305 reinterpret_cast<const TF_DataType*>(values.get()), 306 attr.default_value().list().type_size()); 307 } else if (type == TF_ATTR_SHAPE) { 308 int num_values = attr.default_value().list().shape_size(); 309 (*attr_list_sizes)[key] = num_values; 310 int total_dims = 0; 311 for (int i = 0; i < num_values; ++i) { 312 if (!attr.default_value().list().shape(i).unknown_rank()) { 313 total_dims += attr.default_value().list().shape(i).dim_size(); 314 } 315 } 316 // Allocate a buffer that can fit all of the dims together. 317 std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]); 318 // Copy the input dims into the buffer and set dims to point to 319 // the start of each list's dims. 
320 std::unique_ptr<const int64_t* []> dims(new const int64_t*[num_values]); 321 std::unique_ptr<int[]> num_dims(new int[num_values]); 322 int64_t* offset = buffer.get(); 323 for (int i = 0; i < num_values; ++i) { 324 const auto& shape = attr.default_value().list().shape(i); 325 if (shape.unknown_rank()) { 326 dims[i] = nullptr; 327 num_dims[i] = -1; 328 } else { 329 for (int j = 0; j < shape.dim_size(); j++) { 330 *offset = shape.dim(j).size(); 331 ++offset; 332 } 333 } 334 } 335 TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values, 336 status); 337 } else if (type == TF_ATTR_FUNC) { 338 int num_values = attr.default_value().list().func_size(); 339 (*attr_list_sizes)[key] = num_values; 340 std::unique_ptr<const TFE_Op* []> funcs(new const TFE_Op*[num_values]); 341 for (int i = 0; i < num_values; i++) { 342 funcs[i] = GetFunc(ctx, attr.default_value().list().func(i), status); 343 } 344 TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values); 345 } else { 346 TF_SetStatus(status, TF_UNIMPLEMENTED, 347 "Lists of tensors are not yet implemented for default valued " 348 "attributes for an operation."); 349 } 350 } 351 352 bool SetOpAttrScalar( 353 TFE_Context* ctx, TFE_Op* op, const char* key, PyObject* py_value, 354 TF_AttrType type, 355 tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes, 356 TF_Status* status) { 357 if (type == TF_ATTR_STRING) { 358 const char* value; 359 if (!ParseStringValue(key, py_value, status, &value)) return false; 360 TFE_OpSetAttrString(op, key, value); 361 } else if (type == TF_ATTR_INT) { 362 int64_t value; 363 if (!ParseInt64Value(key, py_value, status, &value)) return false; 364 TFE_OpSetAttrInt(op, key, value); 365 // attr_list_sizes is set for all int attributes (since at this point we are 366 // not aware if that attribute might be used to calculate the size of an 367 // output list or not). 
368 if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = value; 369 } else if (type == TF_ATTR_FLOAT) { 370 float value; 371 if (!ParseFloatValue(key, py_value, status, &value)) return false; 372 TFE_OpSetAttrFloat(op, key, value); 373 } else if (type == TF_ATTR_BOOL) { 374 unsigned char value; 375 if (!ParseBoolValue(key, py_value, status, &value)) return false; 376 TFE_OpSetAttrBool(op, key, value); 377 } else if (type == TF_ATTR_TYPE) { 378 int value; 379 if (!ParseTypeValue(key, py_value, status, &value)) return false; 380 TFE_OpSetAttrType(op, key, static_cast<TF_DataType>(value)); 381 } else if (type == TF_ATTR_SHAPE) { 382 if (py_value == Py_None) { 383 TFE_OpSetAttrShape(op, key, nullptr, -1, status); 384 } else { 385 if (!PySequence_Check(py_value)) { 386 TF_SetStatus(status, TF_INVALID_ARGUMENT, 387 tensorflow::strings::StrCat( 388 "Expecting None or sequence value for attr", key, 389 ", got ", py_value->ob_type->tp_name) 390 .c_str()); 391 return false; 392 } 393 const auto num_dims = TensorShapeNumDims(py_value); 394 if (num_dims == -1) { 395 TFE_OpSetAttrShape(op, key, nullptr, -1, status); 396 return true; 397 } 398 std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]); 399 for (int i = 0; i < num_dims; ++i) { 400 auto inner_py_value = PySequence_ITEM(py_value, i); 401 if (inner_py_value == Py_None) { 402 dims[i] = -1; 403 } else if (!ParseInt64Value(key, inner_py_value, status, &dims[i])) { 404 return false; 405 } 406 } 407 TFE_OpSetAttrShape(op, key, dims.get(), num_dims, status); 408 } 409 if (TF_GetCode(status) != TF_OK) return false; 410 } else if (type == TF_ATTR_FUNC) { 411 // Allow: 412 // (1) String function name, OR 413 // (2) A Python object with a .name attribute 414 // (A crude test for being a 415 // tensorflow.python.framework.function._DefinedFunction) 416 // (which is what the various "defun" or "Defun" decorators do). 
417 // And in the future also allow an object that can encapsulate 418 // the function name and its attribute values. 419 const char* func_name = nullptr; 420 if (!ParseStringValue(key, py_value, status, &func_name)) { 421 PyObject* name_attr = PyObject_GetAttrString(py_value, "name"); 422 if (name_attr == nullptr || 423 !ParseStringValue(key, name_attr, status, &func_name)) { 424 TF_SetStatus( 425 status, TF_INVALID_ARGUMENT, 426 tensorflow::strings::StrCat( 427 "unable to set function value attribute from a ", 428 py_value->ob_type->tp_name, 429 " object. If you think this is an error, please file an issue " 430 "at https://github.com/tensorflow/tensorflow/issues/new") 431 .c_str()); 432 return false; 433 } 434 } 435 TFE_Op* func = TFE_NewOp(ctx, func_name, status); 436 if (TF_GetCode(status) != TF_OK) return false; 437 TFE_OpSetAttrFunction(op, key, func); 438 TFE_DeleteOp(func); 439 } else { 440 TF_SetStatus( 441 status, TF_UNIMPLEMENTED, 442 tensorflow::strings::StrCat("Attr ", key, " has unhandled type ", type) 443 .c_str()); 444 return false; 445 } 446 return true; 447 } 448 449 void SetOpAttrScalarDefault( 450 TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value, 451 const char* attr_name, 452 tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes, 453 TF_Status* status) { 454 switch (default_value.value_case()) { 455 case tensorflow::AttrValue::kS: 456 TFE_OpSetAttrString(op, attr_name, default_value.s().data()); 457 break; 458 case tensorflow::AttrValue::kI: 459 TFE_OpSetAttrInt(op, attr_name, static_cast<int64_t>(default_value.i())); 460 (*attr_list_sizes)[attr_name] = default_value.i(); 461 break; 462 case tensorflow::AttrValue::kF: 463 TFE_OpSetAttrFloat(op, attr_name, default_value.f()); 464 break; 465 case tensorflow::AttrValue::kB: 466 TFE_OpSetAttrBool(op, attr_name, default_value.b()); 467 break; 468 case tensorflow::AttrValue::kType: 469 TFE_OpSetAttrType(op, attr_name, 470 
static_cast<TF_DataType>(default_value.type())); 471 break; 472 case tensorflow::AttrValue::kShape: { 473 const auto& tensor_shape = default_value.shape(); 474 if (tensor_shape.unknown_rank()) { 475 TFE_OpSetAttrShape(op, attr_name, nullptr, -1, status); 476 } else { 477 const auto num_dims = tensor_shape.dim_size(); 478 std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]); 479 for (int i = 0; i < num_dims; ++i) { 480 dims[i] = tensor_shape.dim(i).size(); 481 } 482 TFE_OpSetAttrShape(op, attr_name, dims.get(), num_dims, status); 483 } 484 } break; 485 case tensorflow::AttrValue::kFunc: { 486 const auto func_op = GetFunc(ctx, default_value.func(), status); 487 if (TF_GetCode(status) != TF_OK) return; 488 // TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList 489 // require TFE_Op* and just convert it internally a NameAttrValue, so 490 // consider adding an overload to the C API to make this case easier. 491 TFE_OpSetAttrFunction(op, attr_name, func_op); 492 } break; 493 case tensorflow::AttrValue::kList: 494 TF_FALLTHROUGH_INTENDED; 495 case tensorflow::AttrValue::kTensor: 496 TF_FALLTHROUGH_INTENDED; 497 case tensorflow::AttrValue::kPlaceholder: 498 TF_FALLTHROUGH_INTENDED; 499 case tensorflow::AttrValue::VALUE_NOT_SET: 500 TF_SetStatus( 501 status, TF_UNIMPLEMENTED, 502 tensorflow::strings::StrCat("Unable to get setfor default value: ", 503 default_value.DebugString()) 504 .data()); 505 } 506 } 507 508 // start_index is the index at which the Tuple/List attrs will start getting 509 // processed. 
// Applies the flat (key1, value1, key2, value2, ...) attr pairs stored in the
// Python tuple `attrs`, starting at tuple index `start_index`, to `op`.
// Attr types are looked up from the op definition; list-valued and scalar
// attrs are dispatched to SetOpAttrList/SetOpAttrScalar. Stops at the first
// failure, leaving out_status set.
void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index,
                TF_Status* out_status) {
  if (attrs == Py_None) return;
  Py_ssize_t len = PyTuple_GET_SIZE(attrs) - start_index;
  if ((len & 1) != 0) {
    TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
                 "Expecting attrs tuple to have even length.");
    return;
  }
  // Parse attrs
  for (Py_ssize_t i = 0; i < len; i += 2) {
    PyObject* py_key = PyTuple_GET_ITEM(attrs, start_index + i);
    PyObject* py_value = PyTuple_GET_ITEM(attrs, start_index + i + 1);
#if PY_MAJOR_VERSION >= 3
    const char* key = PyBytes_Check(py_key) ? PyBytes_AsString(py_key)
                                            : PyUnicode_AsUTF8(py_key);
#else
    const char* key = PyBytes_AsString(py_key);
#endif
    unsigned char is_list = 0;
    const TF_AttrType type = TFE_OpGetAttrType(op, key, &is_list, out_status);
    if (TF_GetCode(out_status) != TF_OK) return;
    if (is_list != 0) {
      if (!SetOpAttrList(op, key, py_value, type, nullptr, out_status)) return;
    } else {
      if (!SetOpAttrScalar(ctx, op, key, py_value, type, nullptr, out_status))
        return;
    }
  }
}

// This function will set the op attrs required. If an attr has the value of
// None, then it will read the AttrDef to get the default value and set that
// instead. Any failure in this function will simply fall back to the slow
// path.
void SetOpAttrWithDefaults(
    TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
    const char* attr_name, PyObject* attr_value,
    tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
    TF_Status* status) {
  unsigned char is_list = 0;
  const TF_AttrType type = TFE_OpGetAttrType(op, attr_name, &is_list, status);
  if (TF_GetCode(status) != TF_OK) return;
  if (attr_value == Py_None) {
    // No explicit value: fall back to the OpDef default.
    if (is_list != 0) {
      SetOpAttrListDefault(ctx, op, attr, attr_name, type, attr_list_sizes,
                           status);
    } else {
      SetOpAttrScalarDefault(ctx, op, attr.default_value(), attr_name,
                             attr_list_sizes, status);
    }
  } else {
    if (is_list != 0) {
      SetOpAttrList(op, attr_name, attr_value, type, attr_list_sizes, status);
    } else {
      SetOpAttrScalar(ctx, op, attr_name, attr_value, type, attr_list_sizes,
                      status);
    }
  }
}

// Python subclass of Exception that is created on not ok Status.
tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr;

// Python subclass of Exception that is created to signal fallback.
PyObject* fallback_exception_class = nullptr;

// Mutex-guarded monotonic counter backing get_uid()/TFE_Py_UID().
tensorflow::mutex _uid_mutex(tensorflow::LINKER_INITIALIZED);
tensorflow::int64 _uid GUARDED_BY(_uid_mutex) = 0;

}  // namespace

// Executes `op_name` eagerly: builds a TFE_Op, sets the device, inputs and
// attrs, then runs it (with the GIL released) and resizes `outputs` to the
// number of outputs actually produced. On failure out_status is annotated
// with the op name.
void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
                    const char* op_name, TFE_InputTensorHandles* inputs,
                    PyObject* attrs, TFE_OutputTensorHandles* outputs,
                    TF_Status* out_status) {
  TFE_Op* op = TFE_NewOp(ctx, op_name, out_status);
  if (TF_GetCode(out_status) != TF_OK) return;
  TFE_OpSetDevice(op, device_name, out_status);
  if (TF_GetCode(out_status) == TF_OK) {
    for (int i = 0; i < inputs->size() && TF_GetCode(out_status) == TF_OK;
         ++i) {
      TFE_OpAddInput(op, inputs->at(i), out_status);
    }
  }
  if (TF_GetCode(out_status) == TF_OK) {
    SetOpAttrs(ctx, op, attrs, 0, out_status);
  }
  // Release the GIL around the (potentially long-running) kernel execution.
  Py_BEGIN_ALLOW_THREADS;
  if (TF_GetCode(out_status) == TF_OK) {
    int num_outputs = outputs->size();
    TFE_Execute(op, outputs->data(), &num_outputs, out_status);
    outputs->resize(num_outputs);
  }
  if (TF_GetCode(out_status) != TF_OK) {
    TF_SetStatus(out_status, TF_GetCode(out_status),
                 tensorflow::strings::StrCat(TF_Message(out_status),
                                             " [Op:", op_name, "]")
                     .c_str());
  }
  TFE_DeleteOp(op);
  Py_END_ALLOW_THREADS;
}

// Registers the Python Exception subclass to raise for non-OK statuses.
// Returns None on success; raises TypeError and returns nullptr if `e` is not
// an Exception subclass.
PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
  tensorflow::mutex_lock l(exception_class_mutex);
  if (exception_class != nullptr) {
    Py_DECREF(exception_class);
  }
  if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
    exception_class = nullptr;
    PyErr_SetString(PyExc_TypeError,
                    "TFE_Py_RegisterExceptionClass: "
                    "Registered class should be subclass of Exception.");
    return nullptr;
  } else {
    Py_INCREF(e);
    exception_class = e;
    Py_RETURN_NONE;
  }
}

// Same as TFE_Py_RegisterExceptionClass, but for the exception used to signal
// fallback to the slow execution path.
PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) {
  if (fallback_exception_class != nullptr) {
    Py_DECREF(fallback_exception_class);
  }
  if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
    fallback_exception_class = nullptr;
    PyErr_SetString(PyExc_TypeError,
                    "TFE_Py_RegisterFallbackExceptionClass: "
                    "Registered class should be subclass of Exception.");
    return nullptr;
  } else {
    Py_INCREF(e);
    fallback_exception_class = e;
    Py_RETURN_NONE;
  }
}

// Raises the registered fallback exception with `message`, or RuntimeError if
// no fallback class was registered.
void RaiseFallbackException(const char* message) {
  if (fallback_exception_class != nullptr) {
    PyErr_SetObject(fallback_exception_class, Py_BuildValue("s", message));
    return;
  }

  PyErr_SetString(
      PyExc_RuntimeError,
      tensorflow::strings::StrCat(
          "Fallback exception type not set, attempting to fallback due to ",
          message)
          .data());
}

// If `status` is non-OK, raises `exception` (or, when it is nullptr, the
// registered exception class with (message, code) args, falling back to
// RuntimeError) and returns -1. Returns 0 for an OK status.
int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) {
  if (TF_GetCode(status) == TF_OK) return 0;
  const char* msg = TF_Message(status);
  if (exception == nullptr) {
    tensorflow::mutex_lock l(exception_class_mutex);
    if (exception_class != nullptr) {
      PyErr_SetObject(exception_class,
                      Py_BuildValue("si", msg, TF_GetCode(status)));
      return -1;
    } else {
      exception = PyExc_RuntimeError;
    }
  }
  // May be update already set exception.
  PyErr_SetString(exception, msg);
  return -1;
}

// tensorflow::Status counterpart of MaybeRaiseExceptionFromTFStatus.
int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
                                  PyObject* exception) {
  if (status.ok()) return 0;
  const char* msg = status.error_message().c_str();
  if (exception == nullptr) {
    tensorflow::mutex_lock l(exception_class_mutex);
    if (exception_class != nullptr) {
      PyErr_SetObject(exception_class, Py_BuildValue("si", msg, status.code()));
      return -1;
    } else {
      exception = PyExc_RuntimeError;
    }
  }
  // May be update already set exception.
  PyErr_SetString(exception, msg);
  return -1;
}

// Returns the char* contents of a bytes (or, on Python 3, unicode) object, or
// nullptr for any other type. The pointer aliases o's internal storage.
char* TFE_GetPythonString(PyObject* o) {
  if (PyBytes_Check(o)) {
    return PyBytes_AsString(o);
  }
#if PY_MAJOR_VERSION >= 3
  if (PyUnicode_Check(o)) {
    return PyUnicode_AsUTF8(o);
  }
#endif
  return nullptr;
}

// Returns the next process-wide unique id.
int64_t get_uid() {
  tensorflow::mutex_lock l(_uid_mutex);
  return _uid++;
}

PyObject* TFE_Py_UID() { return PyLong_FromLongLong(get_uid()); }

// PyCapsule destructor for a TFE_Context.
// NOTE(review): any error from TFE_DeleteContext is silently discarded here.
void TFE_DeleteContextCapsule(PyObject* context) {
  TF_Status* status = TF_NewStatus();
  TFE_Context* ctx =
      reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(context, nullptr));
  TFE_DeleteContext(ctx, status);
  TF_DeleteStatus(status);
}

// Converts a Python integer object to int64.
static tensorflow::int64 MakeInt(PyObject* integer) {
#if PY_MAJOR_VERSION >= 3
  return PyLong_AsLong(integer);
#else
  return PyInt_AsLong(integer);
#endif
}

// Returns the id the tape uses for `tensor`: the fast C path for EagerTensor,
// otherwise the Python attribute `_id`. Returns -1 (with a Python error
// pending) when `_id` is missing.
static tensorflow::int64 FastTensorId(PyObject* tensor) {
  if (EagerTensor_CheckExact(tensor)) {
    return EagerTensor_id(tensor);
  }
  PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
  if (id_field == nullptr) {
    return -1;
  }
  tensorflow::int64 id = MakeInt(id_field);
  Py_DECREF(id_field);
  return id;
}

// GradientTape specialization that additionally holds a strong reference to
// every watched variable object for the tape's lifetime.
class GradientTape
    : public tensorflow::eager::GradientTape<PyObject, PyObject> {
 public:
  explicit GradientTape(bool persistent)
      : tensorflow::eager::GradientTape<PyObject, PyObject>(persistent) {}

  virtual ~GradientTape() {
    for (PyObject* v : watched_variables_) {
      Py_DECREF(v);
    }
  }

  // Starts watching `v` (and its handle tensor's id) on this tape.
  void WatchVariable(PyObject* v) {
    auto insert_result = watched_variables_.insert(v);
    if (insert_result.second) {
      // Only increment the reference count if we aren't already watching this
      // variable.
      Py_INCREF(v);
    }
    PyObject* handle = PyObject_GetAttrString(v, "handle");
    if (handle == nullptr) {
      return;
    }
    tensorflow::int64 id = FastTensorId(handle);
    Py_DECREF(handle);
    if (!PyErr_Occurred()) {
      this->Watch(id);
    }
  }

  // Returns (a copy of) the set of watched variable objects.
  const std::unordered_set<PyObject*> WatchedVariables() {
    return watched_variables_;
  }

 private:
  std::unordered_set<PyObject*> watched_variables_;
};

// Python object wrapping a GradientTape (the tfe.Tape type below).
typedef struct {
  PyObject_HEAD
  /* Type-specific fields go here. */
  GradientTape* tape;  // owned; deleted in TFE_Py_Tape_Delete
} TFE_Py_Tape;

static void TFE_Py_Tape_Delete(PyObject* tape) {
  delete reinterpret_cast<TFE_Py_Tape*>(tape)->tape;
  Py_TYPE(tape)->tp_free(tape);
}

static PyTypeObject TFE_Py_Tape_Type = {
    PyVarObject_HEAD_INIT(nullptr, 0) "tfe.Tape", /* tp_name */
    sizeof(TFE_Py_Tape),                          /* tp_basicsize */
    0,                                            /* tp_itemsize */
    &TFE_Py_Tape_Delete,                          /* tp_dealloc */
    nullptr,                                      /* tp_print */
    nullptr,                                      /* tp_getattr */
    nullptr,                                      /* tp_setattr */
    nullptr,                                      /* tp_reserved */
    nullptr,                                      /* tp_repr */
    nullptr,                                      /* tp_as_number */
    nullptr,                                      /* tp_as_sequence */
    nullptr,                                      /* tp_as_mapping */
    nullptr,                                      /* tp_hash */
    nullptr,                                      /* tp_call */
    nullptr,                                      /* tp_str */
    nullptr,                                      /* tp_getattro */
    nullptr,                                      /* tp_setattro */
    nullptr,                                      /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT,                           /* tp_flags */
    "TFE_Py_Tape objects",                        /* tp_doc */
};

// Note: in the current design no mutex is needed here because of the python
// GIL, which is always held when any TFE_Py_* methods are called. We should
// revisit this if/when decide to not hold the GIL while manipulating the tape
// stack.
824 static tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* tape_set = nullptr; 825 tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() { 826 if (tape_set == nullptr) { 827 tape_set = new tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>; 828 } 829 return tape_set; 830 } 831 832 // A safe copy of the current tapeset. Does not get affected by other python 833 // threads changing the set of active tapes. 834 class SafeTapeSet { 835 public: 836 SafeTapeSet() : tape_set_(*GetTapeSet()) { 837 for (auto* tape : tape_set_) { 838 Py_INCREF(tape); 839 } 840 } 841 842 ~SafeTapeSet() { 843 for (auto* tape : tape_set_) { 844 Py_DECREF(tape); 845 } 846 } 847 848 tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>::const_iterator begin() { 849 return tape_set_.begin(); 850 } 851 852 tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>::const_iterator end() { 853 return tape_set_.end(); 854 } 855 856 private: 857 tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*> tape_set_; 858 }; 859 860 // xcode 7 doesn't define thread_local, so for compatibility we implement our 861 // own. TODO(apassos) remove once we can deprecate xcode 7. 
#ifndef __APPLE__
// Returns a pointer to this thread's "taping is stopped" flag.
bool* ThreadTapeIsStopped() {
  thread_local bool thread_tape_is_stopped{false};
  return &thread_tape_is_stopped;
}
#else
// Apple fallback: emulate thread_local with a per-thread-id map (never
// freed; entries default to false).
static std::unordered_map<std::thread::id, bool>* tape_is_stopped = nullptr;
bool* ThreadTapeIsStopped() {
  if (tape_is_stopped == nullptr) {
    tape_is_stopped = new std::unordered_map<std::thread::id, bool>;
  }
  auto it = tape_is_stopped->find(std::this_thread::get_id());
  if (it != tape_is_stopped->end()) {
    return &(it->second);
  }
  return &(tape_is_stopped->emplace(std::this_thread::get_id(), false)
               .first->second);
}
#endif

// Pauses/resumes gradient recording on the calling thread only.
void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; }

void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; }

// Creates a new tape, inserts it into the global tape set and returns it.
// The set holds one reference to the tape; the caller receives another.
PyObject* TFE_Py_TapeSetNew(PyObject* persistent) {
  TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
  if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
  TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
  tape->tape = new GradientTape(persistent == Py_True);
  Py_INCREF(tape);
  GetTapeSet()->insert(reinterpret_cast<TFE_Py_Tape*>(tape));
  return reinterpret_cast<PyObject*>(tape);
}

// True when nothing can record right now: taping is stopped on this thread or
// there are no active tapes.
PyObject* TFE_Py_TapeSetIsEmpty() {
  if (*ThreadTapeIsStopped() || GetTapeSet()->empty()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}

// Removes `tape` from the active set and drops the set's reference to it.
void TFE_Py_TapeSetRemove(PyObject* tape) {
  auto* stack = GetTapeSet();
  stack->erase(reinterpret_cast<TFE_Py_Tape*>(tape));
  // We kept a reference to the tape in the set to ensure it wouldn't get
  // deleted under us; cleaning it up here.
  Py_DECREF(tape);
}

// Converts a Python sequence of ints to a vector; non-int elements map to -1,
// and None or an unconvertible input maps to an empty vector.
static std::vector<tensorflow::int64> MakeIntList(PyObject* list) {
  if (list == Py_None) {
    return {};
  }
  PyObject* seq = PySequence_Fast(list, "expected a sequence");
  if (seq == nullptr) {
    return {};
  }
  int len = PySequence_Size(list);
  std::vector<tensorflow::int64> tensor_ids;
  tensor_ids.reserve(len);
  for (int i = 0; i < len; ++i) {
    PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
#if PY_MAJOR_VERSION >= 3
    if (PyLong_Check(item)) {
#else
    if (PyLong_Check(item) || PyInt_Check(item)) {
#endif
      tensorflow::int64 id = MakeInt(item);
      tensor_ids.push_back(id);
    } else {
      tensor_ids.push_back(-1);
    }
  }
  Py_DECREF(seq);
  return tensor_ids;
}

// Returns True iff any active tape should record an op touching any of
// `tensors` (matched by tensor id).
PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
  if (tensors == Py_None) {
    Py_RETURN_FALSE;
  }
  if (*ThreadTapeIsStopped()) {
    Py_RETURN_FALSE;
  }
  auto* tape_set_ptr = GetTapeSet();
  if (tape_set_ptr->empty()) {
    Py_RETURN_FALSE;
  }
  PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
  if (seq == nullptr) {
    return nullptr;
  }
  int len = PySequence_Fast_GET_SIZE(seq);
  // TODO(apassos) consider not building a list and changing the API to check
  // each tensor individually.
  std::vector<tensorflow::int64> tensor_ids;
  tensor_ids.reserve(len);
  for (int i = 0; i < len; ++i) {
    PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
    tensor_ids.push_back(FastTensorId(item));
  }
  Py_DECREF(seq);
  // Iterate over a copy so tapes added/removed while checking can't
  // invalidate the iteration.
  auto tape_set = *tape_set_ptr;
  for (TFE_Py_Tape* tape : tape_set) {
    if (tape->tape->ShouldRecord(tensor_ids)) {
      Py_RETURN_TRUE;
    }
  }
  Py_RETURN_FALSE;
}

// Starts watching `tensor` on every active tape (no-op while taping is
// stopped on this thread or if the tensor id can't be determined).
void TFE_Py_TapeSetWatch(PyObject* tensor) {
  if (*ThreadTapeIsStopped()) {
    return;
  }
  tensorflow::int64 tensor_id = FastTensorId(tensor);
  if (PyErr_Occurred()) {
    return;
  }
  for (TFE_Py_Tape* tape : *GetTapeSet()) {
    tape->tape->Watch(tensor_id);
  }
}

// Builds the (id, dtype, shape) record the tape keeps for `tensor`. For
// EagerTensors this reads the handle directly; otherwise it goes through the
// Python attributes dtype._type_enum and _shape_tuple(), degrading to an
// empty shape (with the Python error left pending) on failure.
static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
  if (EagerTensor_CheckExact(tensor)) {
    TFE_TensorHandle* t = EagerTensor_Handle(tensor);
    tensorflow::int64 id = EagerTensor_id(tensor);
    return tensorflow::eager::TapeTensor{id, t->t.dtype(), t->t.shape()};
  }
  tensorflow::int64 id = FastTensorId(tensor);
  if (PyErr_Occurred()) {
    return tensorflow::eager::TapeTensor{
        id, static_cast<tensorflow::DataType>(0), tensorflow::TensorShape({})};
  }
  PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype");
  PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum");
  Py_DECREF(dtype_object);
  tensorflow::DataType dtype =
      static_cast<tensorflow::DataType>(MakeInt(dtype_enum));
  Py_DECREF(dtype_enum);
  if (PyErr_Occurred() != nullptr) {
    return tensorflow::eager::TapeTensor{id, dtype,
                                         tensorflow::TensorShape({})};
  }
  static char _shape_tuple[] = "_shape_tuple";
  PyObject* shape_tuple = PyObject_CallMethod(tensor, _shape_tuple, nullptr);
  if (PyErr_Occurred() != nullptr) {
    return tensorflow::eager::TapeTensor{id, dtype,
                                         tensorflow::TensorShape({})};
  }
  auto l = MakeIntList(shape_tuple);
  Py_DECREF(shape_tuple);
  // Replace -1, which represents accidental Nones which can occur in graph
  // mode and can cause errors in shape construction with 0s.
  for (auto& c : l) {
    if (c < 0) {
      c = 0;
    }
  }
  tensorflow::TensorShape shape(l);
  return tensorflow::eager::TapeTensor{id, dtype, shape};
}

// Returns the tape ids of all tensors in `tensors`. May return a shorter
// vector if an error occurs mid-way; the Python error is left pending for the
// caller to check.
std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
  PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
  if (seq == nullptr) {
    return {};
  }
  int len = PySequence_Fast_GET_SIZE(seq);
  std::vector<tensorflow::int64> list;
  list.reserve(len);
  for (int i = 0; i < len; ++i) {
    PyObject* tensor = PySequence_Fast_GET_ITEM(seq, i);
    list.push_back(FastTensorId(tensor));
    if (PyErr_Occurred()) {
      Py_DECREF(seq);
      return list;
    }
  }
  Py_DECREF(seq);
  return list;
}

// Watches `variable` on all active tapes. Iterates over a SafeTapeSet
// snapshot since WatchVariable may run Python code that mutates the live set.
void TFE_Py_TapeSetWatchVariable(PyObject* variable) {
  if (*ThreadTapeIsStopped()) {
    return;
  }
  for (TFE_Py_Tape* tape : SafeTapeSet()) {
    tape->tape->WatchVariable(variable);
  }
}

// Returns a new Python set containing the variables watched by `tape`.
PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
  const std::unordered_set<PyObject*>& watched_variables =
      reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchedVariables();
  PyObject* result = PySet_New(nullptr);
  for (PyObject* variable : watched_variables) {
    PySet_Add(result, variable);
  }
  return result;
}

// Records an operation (inputs, outputs and its backward function) on every
// active tape. (Definition continues beyond this chunk.)
void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
                                   PyObject* input_tensors,
                                   PyObject* backward_function) {
  if (GetTapeSet()->empty() || *ThreadTapeIsStopped()) {
    return;
  }
  std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
  if (PyErr_Occurred()) {
    return;
  }
  std::vector<tensorflow::eager::TapeTensor> output_info;
  PyObject* seq = PySequence_Fast(output_tensors,
                                  "expected
a sequence of integer tensor ids"); 1078 int len = PySequence_Size(output_tensors); 1079 output_info.reserve(len); 1080 for (int i = 0; i < len; ++i) { 1081 output_info.push_back( 1082 TapeTensorFromTensor(PySequence_Fast_GET_ITEM(seq, i))); 1083 if (PyErr_Occurred() != nullptr) { 1084 Py_DECREF(seq); 1085 return; 1086 } 1087 } 1088 Py_DECREF(seq); 1089 string op_type_str; 1090 if (PyBytes_Check(op_type)) { 1091 op_type_str = PyBytes_AsString(op_type); 1092 } else if (PyUnicode_Check(op_type)) { 1093 #if PY_MAJOR_VERSION >= 3 1094 op_type_str = PyUnicode_AsUTF8(op_type); 1095 #else 1096 PyObject* py_str = PyUnicode_AsUTF8String(op_type); 1097 if (py_str == nullptr) return; 1098 op_type_str = PyBytes_AS_STRING(py_str); 1099 Py_DECREF(py_str); 1100 #endif 1101 } else { 1102 PyErr_SetString(PyExc_RuntimeError, "op_type should be a string."); 1103 return; 1104 } 1105 1106 for (TFE_Py_Tape* tape : SafeTapeSet()) { 1107 Py_INCREF(backward_function); 1108 tape->tape->RecordOperation( 1109 op_type_str, output_info, input_ids, backward_function, 1110 [backward_function]() { Py_DECREF(backward_function); }); 1111 } 1112 } 1113 1114 void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) { 1115 for (TFE_Py_Tape* tape : SafeTapeSet()) { 1116 tape->tape->DeleteTrace(tensor_id); 1117 } 1118 } 1119 1120 class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyObject> { 1121 public: 1122 explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {} 1123 1124 tensorflow::Status Initialize() { 1125 num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn"); 1126 if (num_elements_ == nullptr) { 1127 return tensorflow::errors::InvalidArgument("invalid vspace"); 1128 } 1129 aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn"); 1130 if (aggregate_fn_ == nullptr) { 1131 return tensorflow::errors::InvalidArgument("invalid vspace"); 1132 } 1133 zeros_ = PyObject_GetAttrString(py_vspace_, "zeros"); 1134 if (zeros_ == nullptr) { 1135 return 
tensorflow::errors::InvalidArgument("invalid vspace"); 1136 } 1137 ones_ = 1138 PyObject_GetAttrString(reinterpret_cast<PyObject*>(py_vspace_), "ones"); 1139 if (ones_ == nullptr) { 1140 return tensorflow::errors::InvalidArgument("invalid vspace"); 1141 } 1142 return tensorflow::Status::OK(); 1143 } 1144 1145 ~PyVSpace() override { 1146 Py_XDECREF(num_elements_); 1147 Py_XDECREF(aggregate_fn_); 1148 Py_XDECREF(zeros_); 1149 Py_XDECREF(ones_); 1150 } 1151 1152 tensorflow::int64 NumElements(PyObject* tensor) const final { 1153 PyObject* arglist = 1154 Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor)); 1155 PyObject* result = PyEval_CallObject(num_elements_, arglist); 1156 tensorflow::int64 r = MakeInt(result); 1157 Py_DECREF(result); 1158 Py_DECREF(arglist); 1159 return r; 1160 } 1161 1162 PyObject* AggregateGradients( 1163 tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final { 1164 PyObject* list = PyList_New(gradient_tensors.size()); 1165 for (int i = 0; i < gradient_tensors.size(); ++i) { 1166 // Note: stealing a reference to the gradient tensors. 
1167 CHECK(gradient_tensors[i] != nullptr); 1168 CHECK(gradient_tensors[i] != Py_None); 1169 PyList_SET_ITEM(list, i, 1170 reinterpret_cast<PyObject*>(gradient_tensors[i])); 1171 } 1172 PyObject* arglist = Py_BuildValue("(O)", list); 1173 CHECK(arglist != nullptr); 1174 PyObject* result = PyEval_CallObject(aggregate_fn_, arglist); 1175 Py_DECREF(arglist); 1176 Py_DECREF(list); 1177 return result; 1178 } 1179 1180 PyObject* Zeros(tensorflow::TensorShape shape, 1181 tensorflow::DataType dtype) const final { 1182 PyObject* py_shape = PyTuple_New(shape.dims()); 1183 for (int i = 0; i < shape.dims(); ++i) { 1184 PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i))); 1185 } 1186 PyObject* py_dtype = PyLong_FromLong(static_cast<int>(dtype)); 1187 PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype); 1188 PyObject* result = PyEval_CallObject(zeros_, arg_list); 1189 Py_DECREF(arg_list); 1190 Py_DECREF(py_dtype); 1191 Py_DECREF(py_shape); 1192 return reinterpret_cast<PyObject*>(result); 1193 } 1194 1195 PyObject* Ones(tensorflow::TensorShape shape, 1196 tensorflow::DataType dtype) const final { 1197 PyObject* py_shape = PyTuple_New(shape.dims()); 1198 for (int i = 0; i < shape.dims(); ++i) { 1199 PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i))); 1200 } 1201 PyObject* py_dtype = PyLong_FromLong(static_cast<int>(dtype)); 1202 PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype); 1203 PyObject* result = PyEval_CallObject(ones_, arg_list); 1204 Py_DECREF(arg_list); 1205 Py_DECREF(py_dtype); 1206 Py_DECREF(py_shape); 1207 return result; 1208 } 1209 1210 tensorflow::Status CallBackwardFunction( 1211 PyObject* backward_function, 1212 tensorflow::gtl::ArraySlice<PyObject*> output_gradients, 1213 std::vector<PyObject*>* result) const final { 1214 PyObject* grads = PyTuple_New(output_gradients.size()); 1215 for (int i = 0; i < output_gradients.size(); ++i) { 1216 if (output_gradients[i] == nullptr) { 1217 Py_INCREF(Py_None); 1218 
PyTuple_SET_ITEM(grads, i, Py_None); 1219 } else { 1220 PyTuple_SET_ITEM(grads, i, 1221 reinterpret_cast<PyObject*>(output_gradients[i])); 1222 } 1223 } 1224 PyObject* py_result = PyEval_CallObject( 1225 reinterpret_cast<PyObject*>(backward_function), grads); 1226 Py_DECREF(grads); 1227 if (py_result == nullptr) { 1228 return tensorflow::errors::Internal("gradient function threw exceptions"); 1229 } 1230 result->clear(); 1231 PyObject* seq = 1232 PySequence_Fast(py_result, "expected a sequence of gradients"); 1233 if (seq == nullptr) { 1234 return tensorflow::errors::InvalidArgument( 1235 "gradient function did not return a list"); 1236 } 1237 int len = PySequence_Fast_GET_SIZE(seq); 1238 VLOG(1) << "Gradient length is " << len; 1239 result->reserve(len); 1240 for (int i = 0; i < len; ++i) { 1241 PyObject* item = PySequence_Fast_GET_ITEM(seq, i); 1242 if (item == Py_None) { 1243 result->push_back(nullptr); 1244 } else { 1245 Py_INCREF(item); 1246 result->push_back(item); 1247 } 1248 } 1249 Py_DECREF(seq); 1250 Py_DECREF(py_result); 1251 return tensorflow::Status::OK(); 1252 } 1253 1254 void ReleaseBackwardFunction(PyObject* backward_function) const final { 1255 Py_DECREF(backward_function); 1256 } 1257 1258 void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); } 1259 1260 private: 1261 PyObject* py_vspace_; 1262 1263 PyObject* num_elements_; 1264 PyObject* aggregate_fn_; 1265 PyObject* zeros_; 1266 PyObject* ones_; 1267 }; 1268 1269 std::vector<PyObject*> MakeTensorList(PyObject* tensors) { 1270 PyObject* seq = PySequence_Fast(tensors, "expected a sequence"); 1271 if (seq == nullptr) { 1272 return {}; 1273 } 1274 int len = PySequence_Fast_GET_SIZE(seq); 1275 std::vector<PyObject*> list; 1276 list.reserve(len); 1277 for (int i = 0; i < len; ++i) { 1278 list.push_back(PySequence_Fast_GET_ITEM(seq, i)); 1279 } 1280 Py_DECREF(seq); 1281 return list; 1282 } 1283 1284 PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace, 1285 PyObject* target, 
PyObject* sources, 1286 PyObject* output_gradients, TF_Status* status) { 1287 PyVSpace c_vspace(vspace); 1288 if (!c_vspace.Initialize().ok()) { 1289 return nullptr; 1290 } 1291 1292 std::vector<tensorflow::int64> target_vec = MakeTensorIDList(target); 1293 if (PyErr_Occurred()) { 1294 return nullptr; 1295 } 1296 std::vector<tensorflow::int64> sources_vec = MakeTensorIDList(sources); 1297 if (PyErr_Occurred()) { 1298 return nullptr; 1299 } 1300 std::vector<PyObject*> outgrad_vec; 1301 if (output_gradients != Py_None) { 1302 outgrad_vec = MakeTensorList(output_gradients); 1303 if (PyErr_Occurred()) { 1304 return nullptr; 1305 } 1306 for (PyObject* tensor : outgrad_vec) { 1307 // Calling the backward function will eat a reference to the tensors in 1308 // outgrad_vec, so we need to increase their reference count. 1309 Py_INCREF(tensor); 1310 } 1311 } 1312 TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape); 1313 std::vector<PyObject*> result; 1314 status->status = tape_obj->tape->ComputeGradient( 1315 c_vspace, target_vec, sources_vec, outgrad_vec, &result); 1316 if (!status->status.ok()) { 1317 if (PyErr_Occurred()) { 1318 // Do not propagate the erroneous status as that would swallow the 1319 // exception which caused the problem. 
1320 status->status = tensorflow::Status::OK(); 1321 } 1322 return nullptr; 1323 } 1324 if (!result.empty()) { 1325 PyObject* py_result = PyList_New(result.size()); 1326 for (int i = 0; i < result.size(); ++i) { 1327 if (result[i] == nullptr) { 1328 Py_INCREF(Py_None); 1329 result[i] = Py_None; 1330 } 1331 PyList_SET_ITEM(py_result, i, reinterpret_cast<PyObject*>(result[i])); 1332 } 1333 return py_result; 1334 } 1335 return PyList_New(0); 1336 } 1337 1338 namespace { 1339 static const int kFastPathExecuteInputStartIndex = 6; 1340 1341 PyObject* GetPythonObjectFromString(const char* s) { 1342 #if PY_MAJOR_VERSION >= 3 1343 return PyUnicode_FromString(s); 1344 #else 1345 return PyBytes_FromString(s); 1346 #endif 1347 } 1348 1349 bool CheckEagerTensors(PyObject* seq, int start_index, 1350 const tensorflow::OpDef& op_def) { 1351 for (int i = 0; i < op_def.input_arg_size(); i++) { 1352 PyObject* item = PyTuple_GET_ITEM(seq, i + start_index); 1353 if (!op_def.input_arg(i).number_attr().empty() || 1354 !op_def.input_arg(i).type_list_attr().empty()) { 1355 // This item should be a list input. 1356 if (!PyList_Check(item)) return false; 1357 for (Py_ssize_t j = 0; j < PyList_Size(item); j++) { 1358 if (!EagerTensor_CheckExact(PyList_GET_ITEM(item, j))) return false; 1359 } 1360 } else if (!EagerTensor_CheckExact(item)) { 1361 return false; 1362 } 1363 } 1364 1365 return true; 1366 } 1367 1368 // Adds input and type attr to the op, and to the list of flattened 1369 // inputs/attrs. 
1370 bool AddInputToOp(PyObject* input, const tensorflow::OpDef::ArgDef* input_arg, 1371 std::vector<PyObject*>* flattened_attrs, 1372 std::vector<PyObject*>* flattened_inputs, TFE_Op* op, 1373 TF_Status* status) { 1374 TFE_TensorHandle* input_handle = EagerTensor_Handle(input); 1375 if (input_arg != nullptr && !input_arg->type_attr().empty()) { 1376 auto dtype = TFE_TensorHandleDataType(input_handle); 1377 TFE_OpSetAttrType(op, input_arg->type_attr().data(), dtype); 1378 if (flattened_attrs != nullptr) { 1379 flattened_attrs->push_back( 1380 GetPythonObjectFromString(input_arg->type_attr().data())); 1381 flattened_attrs->push_back(PyLong_FromLong(dtype)); 1382 } 1383 } 1384 1385 if (flattened_inputs != nullptr) { 1386 flattened_inputs->push_back(input); 1387 } 1388 TFE_OpAddInput(op, input_handle, status); 1389 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { 1390 return false; 1391 } 1392 return true; 1393 } 1394 1395 const tensorflow::OpDef* GetOpDef(PyObject* py_op_name) { 1396 const char* op_name = TFE_GetPythonString(py_op_name); 1397 if (op_name == nullptr) { 1398 PyErr_SetString(PyExc_TypeError, 1399 Printf("expected a string for op_name, got %s instead", 1400 py_op_name->ob_type->tp_name) 1401 .c_str()); 1402 return nullptr; 1403 } 1404 1405 const tensorflow::OpRegistrationData* op_reg_data = nullptr; 1406 const tensorflow::Status lookup_status = 1407 tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data); 1408 if (MaybeRaiseExceptionFromStatus(lookup_status, nullptr)) { 1409 return nullptr; 1410 } 1411 return &op_reg_data->op_def; 1412 } 1413 1414 const char* GetDeviceName(PyObject* py_device_name) { 1415 if (py_device_name != Py_None) { 1416 return TFE_GetPythonString(py_device_name); 1417 } 1418 return nullptr; 1419 } 1420 1421 bool RaiseIfNotPyList(PyObject* list, const string& attr_name) { 1422 if (!PyList_Check(list)) { 1423 PyErr_SetString(PyExc_TypeError, 1424 Printf("expected a list for attr %s, got %s instead", 1425 
attr_name.data(), list->ob_type->tp_name) 1426 .data()); 1427 1428 return false; 1429 } 1430 return true; 1431 } 1432 1433 bool RunCallbacks(bool run_gradient_callback, bool run_post_exec_callbacks, 1434 const tensorflow::OpDef* op_def, PyObject* args, 1435 const std::vector<PyObject*>& flattened_inputs, 1436 const std::vector<PyObject*>& flattened_attrs, 1437 PyObject* flattened_result, PyObject* op_name, PyObject* name, 1438 PyObject* record_gradient_callback, PyObject* callbacks) { 1439 PyObject* inputs = PyTuple_New(flattened_inputs.size()); 1440 for (int i = 0; i < flattened_inputs.size(); i++) { 1441 PyObject* input = flattened_inputs[i]; 1442 Py_INCREF(input); 1443 PyTuple_SET_ITEM(inputs, i, input); 1444 } 1445 1446 int num_non_inferred_attrs = PyTuple_GET_SIZE(args) - 1447 op_def->input_arg_size() - 1448 kFastPathExecuteInputStartIndex; 1449 int num_attrs = flattened_attrs.size() + num_non_inferred_attrs; 1450 PyObject* attrs = PyTuple_New(num_attrs); 1451 1452 for (int i = 0; i < num_non_inferred_attrs; i++) { 1453 auto* attr = PyTuple_GET_ITEM( 1454 args, kFastPathExecuteInputStartIndex + op_def->input_arg_size() + i); 1455 Py_INCREF(attr); 1456 PyTuple_SET_ITEM(attrs, i, attr); 1457 } 1458 for (int i = num_non_inferred_attrs; i < num_attrs; i++) { 1459 // Not INCREFing anything in flattened_attrs as each of those is a new 1460 // reference, so allow the attrs tuple to steal the reference. 
1461 PyTuple_SET_ITEM(attrs, i, flattened_attrs.at(i - num_non_inferred_attrs)); 1462 } 1463 1464 PyObject* callback_args = 1465 Py_BuildValue("OOOOO", op_name, inputs, attrs, flattened_result, name); 1466 1467 auto cleaner = tensorflow::gtl::MakeCleanup([inputs, attrs, callback_args] { 1468 Py_DECREF(inputs); 1469 Py_DECREF(attrs); 1470 Py_DECREF(callback_args); 1471 }); 1472 1473 if (run_gradient_callback) { 1474 if (!PyCallable_Check(record_gradient_callback)) { 1475 PyErr_SetString(PyExc_TypeError, 1476 Printf("expected a function for " 1477 "record_gradient_callback, got %s instead", 1478 record_gradient_callback->ob_type->tp_name) 1479 .c_str()); 1480 return false; 1481 } 1482 1483 PyObject* callback_result = 1484 PyObject_CallObject(record_gradient_callback, callback_args); 1485 if (!callback_result) { 1486 return false; 1487 } 1488 Py_DECREF(callback_result); 1489 } 1490 1491 if (run_post_exec_callbacks) { 1492 for (Py_ssize_t i = 0; i < PyList_Size(callbacks); i++) { 1493 PyObject* callback_fn = PyList_GET_ITEM(callbacks, i); 1494 if (!PyCallable_Check(callback_fn)) { 1495 PyErr_SetString( 1496 PyExc_TypeError, 1497 Printf("expected a function for " 1498 "post execution callback in index %ld, got %s instead", 1499 i, callback_fn->ob_type->tp_name) 1500 .c_str()); 1501 return false; 1502 } 1503 PyObject* callback_result = 1504 PyObject_CallObject(callback_fn, callback_args); 1505 if (!callback_result) { 1506 return false; 1507 } 1508 Py_DECREF(callback_result); 1509 } 1510 } 1511 1512 return true; 1513 } 1514 1515 } // namespace 1516 1517 PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) { 1518 Py_ssize_t args_size = PyTuple_GET_SIZE(args); 1519 if (args_size < kFastPathExecuteInputStartIndex) { 1520 PyErr_SetString( 1521 PyExc_ValueError, 1522 Printf("There must be at least %d items in the input tuple.", 1523 kFastPathExecuteInputStartIndex) 1524 .c_str()); 1525 return nullptr; 1526 } 1527 1528 TFE_Context* ctx = reinterpret_cast<TFE_Context*>( 
1529 PyCapsule_GetPointer(PyTuple_GET_ITEM(args, 0), nullptr)); 1530 const char* device_name = GetDeviceName(PyTuple_GET_ITEM(args, 1)); 1531 PyObject* op_name = PyTuple_GET_ITEM(args, 2); 1532 const tensorflow::OpDef* op_def = GetOpDef(op_name); 1533 if (op_def == nullptr) return nullptr; 1534 PyObject* record_gradient_callback = PyTuple_GET_ITEM(args, 3); 1535 PyObject* name = PyTuple_GET_ITEM(args, 4); 1536 PyObject* callbacks = PyTuple_GET_ITEM(args, 5); 1537 1538 if (args_size < kFastPathExecuteInputStartIndex + op_def->input_arg_size()) { 1539 PyErr_SetString( 1540 PyExc_ValueError, 1541 Printf("Tuple size smaller than intended. Expected to be at least %d, " 1542 "was %ld", 1543 kFastPathExecuteInputStartIndex + op_def->input_arg_size(), 1544 args_size) 1545 .c_str()); 1546 return nullptr; 1547 } 1548 1549 if (!CheckEagerTensors(args, kFastPathExecuteInputStartIndex, *op_def)) { 1550 RaiseFallbackException( 1551 "This function does not handle the case of the path where " 1552 "all inputs are not already EagerTensors."); 1553 return nullptr; 1554 } 1555 1556 TF_Status* status = TF_NewStatus(); 1557 TFE_Op* op = TFE_NewOp(ctx, op_def->name().c_str(), status); 1558 auto cleaner = tensorflow::gtl::MakeCleanup([status, op] { 1559 TF_DeleteStatus(status); 1560 TFE_DeleteOp(op); 1561 }); 1562 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { 1563 return nullptr; 1564 } 1565 1566 // Mapping of attr name to size - used to calculate the number of values 1567 // to be expected by the TFE_Execute run. 1568 tensorflow::gtl::FlatMap<string, tensorflow::int64> attr_list_sizes; 1569 1570 // Set non-inferred attrs, including setting defaults if the attr is passed in 1571 // as None. 
1572 for (int i = kFastPathExecuteInputStartIndex + op_def->input_arg_size(); 1573 i < args_size; i += 2) { 1574 PyObject* py_attr_name = PyTuple_GET_ITEM(args, i); 1575 const tensorflow::StringPiece attr_name(TFE_GetPythonString(py_attr_name)); 1576 PyObject* py_attr_value = PyTuple_GET_ITEM(args, i + 1); 1577 1578 // Not creating an index since most of the time there are not more than a 1579 // few attrs. 1580 // TODO(nareshmodi): Maybe include the index as part of the 1581 // OpRegistrationData. 1582 for (const auto& attr : op_def->attr()) { 1583 if (attr_name == attr.name()) { 1584 SetOpAttrWithDefaults(ctx, op, attr, attr_name.data(), py_attr_value, 1585 &attr_list_sizes, status); 1586 1587 if (TF_GetCode(status) != TF_OK) { 1588 RaiseFallbackException(TF_Message(status)); 1589 return nullptr; 1590 } 1591 1592 break; 1593 } 1594 } 1595 } 1596 1597 TFE_OpSetDevice(op, device_name, status); 1598 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { 1599 return nullptr; 1600 } 1601 1602 // TODO(nareshmodi): Add a benchmark for the fast-path with gradient callbacks 1603 // (similar to benchmark_tf_gradient_function_*). Also consider using an 1604 // InlinedVector for flattened_attrs and flattened_inputs if the benchmarks 1605 // point out problems with heap allocs. 1606 bool run_gradient_callback = !*ThreadTapeIsStopped() && 1607 !GetTapeSet()->empty() && 1608 record_gradient_callback != Py_None; 1609 bool run_post_exec_callbacks = 1610 callbacks != Py_None && PyList_Size(callbacks) > 0; 1611 bool run_callbacks = run_gradient_callback || run_post_exec_callbacks; 1612 // Flat attrs and inputs as required by the record_gradient call. The attrs 1613 // here only contain inferred attrs (non-inferred attrs are added directly 1614 // from the input args). 1615 // All items in flattened_attrs contain new references. 1616 // All items in flattened_inputs contain borrowed references. 
1617 // TODO(nareshmodi): figure out why PyList_New/PyList_Append don't work 1618 // directly. 1619 std::unique_ptr<std::vector<PyObject*>> flattened_attrs = nullptr; 1620 std::unique_ptr<std::vector<PyObject*>> flattened_inputs = nullptr; 1621 1622 if (run_callbacks) { 1623 flattened_attrs.reset(new std::vector<PyObject*>); 1624 flattened_inputs.reset(new std::vector<PyObject*>); 1625 } 1626 1627 // Add inferred attrs and inputs. 1628 // The following code might set duplicate type attrs. This will result in 1629 // the CacheKey for the generated AttrBuilder possibly differing from 1630 // those where the type attrs are correctly set. Inconsistent CacheKeys 1631 // for ops means that there might be unnecessarily duplicated kernels. 1632 // TODO(nareshmodi): Fix this. 1633 for (int i = 0; i < op_def->input_arg_size(); i++) { 1634 const auto& input_arg = op_def->input_arg(i); 1635 1636 PyObject* input = 1637 PyTuple_GET_ITEM(args, kFastPathExecuteInputStartIndex + i); 1638 if (!input_arg.number_attr().empty()) { 1639 // The item is a homogeneous list. 1640 if (!RaiseIfNotPyList(input, input_arg.number_attr())) return nullptr; 1641 Py_ssize_t len = PyList_Size(input); 1642 1643 TFE_OpSetAttrInt(op, input_arg.number_attr().data(), len); 1644 if (run_callbacks) { 1645 flattened_attrs->push_back( 1646 GetPythonObjectFromString(input_arg.number_attr().data())); 1647 flattened_attrs->push_back(PyLong_FromLong(len)); 1648 } 1649 attr_list_sizes[input_arg.number_attr()] = len; 1650 1651 if (len > 0) { 1652 // First item adds the type attr. 1653 if (!AddInputToOp(PyList_GET_ITEM(input, 0), &input_arg, 1654 flattened_attrs.get(), flattened_inputs.get(), op, 1655 status)) { 1656 return nullptr; 1657 } 1658 1659 for (Py_ssize_t j = 1; j < len; j++) { 1660 // Since the list is homogeneous, we don't need to re-add the attr. 
1661 if (!AddInputToOp(PyList_GET_ITEM(input, j), nullptr /* input_arg */, 1662 nullptr /* flattened_attrs */, 1663 flattened_inputs.get(), op, status)) { 1664 return nullptr; 1665 } 1666 } 1667 } 1668 } else if (!input_arg.type_list_attr().empty()) { 1669 // The item is a heterogeneous list. 1670 if (!RaiseIfNotPyList(input, input_arg.type_list_attr())) return nullptr; 1671 const string& attr_name = input_arg.type_list_attr(); 1672 Py_ssize_t len = PyList_Size(input); 1673 tensorflow::gtl::InlinedVector<TF_DataType, 4> attr_value(len); 1674 PyObject* py_attr_value = nullptr; 1675 if (run_callbacks) { 1676 py_attr_value = PyTuple_New(len); 1677 } 1678 for (Py_ssize_t j = 0; j < len; j++) { 1679 PyObject* py_input = PyList_GET_ITEM(input, j); 1680 TFE_TensorHandle* input_handle = EagerTensor_Handle(py_input); 1681 attr_value[j] = TFE_TensorHandleDataType(input_handle); 1682 1683 TFE_OpAddInput(op, input_handle, status); 1684 if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { 1685 return nullptr; 1686 } 1687 1688 if (run_callbacks) { 1689 flattened_inputs->push_back(py_input); 1690 1691 PyTuple_SET_ITEM(py_attr_value, j, PyLong_FromLong(attr_value[j])); 1692 } 1693 } 1694 if (run_callbacks) { 1695 flattened_attrs->push_back(GetPythonObjectFromString(attr_name.data())); 1696 flattened_attrs->push_back(py_attr_value); 1697 } 1698 TFE_OpSetAttrTypeList(op, attr_name.data(), attr_value.data(), 1699 attr_value.size()); 1700 attr_list_sizes[attr_name] = len; 1701 } else { 1702 // The item is a single item. 
1703 if (!AddInputToOp(input, &input_arg, flattened_attrs.get(), 1704 flattened_inputs.get(), op, status)) { 1705 return nullptr; 1706 } 1707 } 1708 } 1709 1710 int num_retvals = 0; 1711 for (int i = 0; i < op_def->output_arg_size(); i++) { 1712 const auto& output_arg = op_def->output_arg(i); 1713 if (!output_arg.number_attr().empty()) { 1714 num_retvals += attr_list_sizes[output_arg.number_attr()]; 1715 } else if (!output_arg.type_list_attr().empty()) { 1716 num_retvals += attr_list_sizes[output_arg.type_list_attr()]; 1717 } else { 1718 num_retvals++; 1719 } 1720 } 1721 1722 tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals); 1723 1724 Py_BEGIN_ALLOW_THREADS; 1725 TFE_Execute(op, retvals.data(), &num_retvals, status); 1726 Py_END_ALLOW_THREADS; 1727 if (TF_GetCode(status) != TF_OK) { 1728 // Augment the status with the op_name for easier debugging similar to 1729 // TFE_Py_Execute. 1730 TF_SetStatus(status, TF_GetCode(status), 1731 tensorflow::strings::StrCat(TF_Message(status), " [Op:", 1732 TFE_GetPythonString(op_name), "]") 1733 .c_str()); 1734 1735 MaybeRaiseExceptionFromTFStatus(status, nullptr); 1736 return nullptr; 1737 } 1738 1739 PyObject* flat_result = PyList_New(num_retvals); 1740 for (int i = 0; i < num_retvals; ++i) { 1741 PyList_SET_ITEM(flat_result, i, EagerTensorFromHandle(retvals[i])); 1742 } 1743 1744 if (run_callbacks && 1745 !RunCallbacks(run_gradient_callback, run_post_exec_callbacks, op_def, 1746 args, *flattened_inputs, *flattened_attrs, flat_result, 1747 op_name, name, record_gradient_callback, callbacks)) { 1748 return nullptr; 1749 } 1750 1751 // Unflatten results. 
1752 if (op_def->output_arg_size() == 0) { 1753 Py_RETURN_NONE; 1754 } 1755 1756 if (op_def->output_arg_size() == 1) { 1757 if (!op_def->output_arg(0).number_attr().empty() || 1758 !op_def->output_arg(0).type_list_attr().empty()) { 1759 return flat_result; 1760 } else { 1761 auto* result = PyList_GET_ITEM(flat_result, 0); 1762 Py_INCREF(result); 1763 Py_DECREF(flat_result); 1764 return result; 1765 } 1766 } 1767 1768 // Correctly output the results that are made into a namedtuple. 1769 PyObject* result = PyList_New(op_def->output_arg_size()); 1770 int flat_result_index = 0; 1771 for (int i = 0; i < op_def->output_arg_size(); i++) { 1772 if (!op_def->output_arg(i).number_attr().empty()) { 1773 int list_length = attr_list_sizes[op_def->output_arg(i).number_attr()]; 1774 PyObject* inner_list = PyList_New(list_length); 1775 for (int j = 0; j < list_length; j++) { 1776 PyObject* obj = PyList_GET_ITEM(flat_result, flat_result_index++); 1777 Py_INCREF(obj); 1778 PyList_SET_ITEM(inner_list, j, obj); 1779 } 1780 PyList_SET_ITEM(result, i, inner_list); 1781 } else if (!op_def->output_arg(i).type_list_attr().empty()) { 1782 int list_length = attr_list_sizes[op_def->output_arg(i).type_list_attr()]; 1783 PyObject* inner_list = PyList_New(list_length); 1784 for (int j = 0; j < list_length; j++) { 1785 PyObject* obj = PyList_GET_ITEM(flat_result, flat_result_index++); 1786 Py_INCREF(obj); 1787 PyList_SET_ITEM(inner_list, j, obj); 1788 } 1789 PyList_SET_ITEM(result, i, inner_list); 1790 } else { 1791 PyObject* obj = PyList_GET_ITEM(flat_result, flat_result_index++); 1792 Py_INCREF(obj); 1793 PyList_SET_ITEM(result, i, obj); 1794 } 1795 } 1796 Py_DECREF(flat_result); 1797 return result; 1798 } 1799