1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <stdlib.h> 17 18 #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h" 19 #include "tensorflow/python/lib/core/numpy.h" 20 #include "tensorflow/python/lib/core/py_seq_tensor.h" 21 #include "tensorflow/python/lib/core/safe_ptr.h" 22 23 #include "tensorflow/python/eager/pywrap_tensor.h" 24 #include "tensorflow/python/eager/pywrap_tfe.h" 25 26 #include "tensorflow/c/c_api.h" 27 #include "tensorflow/core/lib/strings/strcat.h" 28 #include "tensorflow/python/lib/core/ndarray_tensor.h" 29 30 namespace { 31 32 TFE_Context* GetContext(PyObject* ctx) { 33 TFE_Context* context = 34 reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(ctx, nullptr)); 35 if (context == nullptr) { 36 PyErr_SetString(PyExc_TypeError, 37 tensorflow::strings::StrCat( 38 "Expecting a PyCapsule encoded context handle. Got ", 39 Py_TYPE(ctx)->tp_name) 40 .c_str()); 41 } 42 return context; 43 } 44 45 // Convert a Python numpy.ndarray object to a TFE_TensorHandle. 46 // The two may share underlying storage so changes to one may reflect in the 47 // other. 48 TFE_TensorHandle* NumpyToTensorHandle(PyObject* obj) { 49 tensorflow::Tensor t; 50 auto cppstatus = tensorflow::NdarrayToTensor(obj, &t); 51 if (cppstatus.ok()) { 52 return TFE_NewTensorHandle(t); 53 } else { 54 PyErr_SetString(PyExc_ValueError, 55 tensorflow::strings::StrCat( 56 "Failed to convert numpy ndarray to a Tensor (", 57 cppstatus.error_message(), ").") 58 .c_str()); 59 return nullptr; 60 } 61 } 62 63 // Casts data referred to by `handle` from type `src_type_enum` to type 64 // `dst_type_enum`. 65 TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle, 66 TF_DataType src_type_enum, 67 TF_DataType dst_type_enum, TF_Status* out_status) { 68 if (ctx == nullptr) return nullptr; 69 const char* op_name = "Cast"; 70 const char* device_name = "/job:localhost/replica:0/task:0/device:CPU:0"; 71 TFE_Op* op = TFE_NewOp(ctx, op_name, out_status); 72 #define RETURN_ERROR \ 73 { \ 74 TFE_DeleteOp(op); \ 75 return nullptr; \ 76 } 77 if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR 78 TFE_OpSetDevice(op, device_name, out_status); 79 if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR 80 TFE_OpAddInput(op, handle, out_status); 81 if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR 82 TFE_OpSetAttrType(op, "SrcT", src_type_enum); 83 TFE_OpSetAttrType(op, "DstT", dst_type_enum); 84 TFE_TensorHandle* output = nullptr; 85 int num_outputs = 1; 86 TFE_Execute(op, &output, &num_outputs, out_status); 87 if (TF_GetCode(out_status) != TF_OK || num_outputs != 1 || 88 output == nullptr) { 89 if (output != nullptr) { 90 TFE_DeleteTensorHandle(output); 91 } 92 RETURN_ERROR 93 } 94 TFE_DeleteOp(op); 95 return output; 96 #undef RETURN_ERROR 97 } 98 99 TFE_TensorHandle* CopyToDevice(TFE_TensorHandle* handle, PyObject* ctx, 100 PyObject* dev) { 101 const char* device = ""; 102 if (dev != nullptr && dev != Py_None) { 103 device = PyBytes_AsString(dev); 104 #if PY_MAJOR_VERSION >= 3 105 if (device == nullptr) { 106 PyErr_Clear(); 107 device = PyUnicode_AsUTF8(dev); 108 } 109 #endif 110 if (device == nullptr) { 111 PyErr_SetString(PyExc_TypeError, 112 "Error parsing device argument to CopyToDevice"); 113 return nullptr; 114 } 115 } 116 TFE_Context* context = GetContext(ctx); 117 if (context == nullptr) { // PyErr already set by GetContext 118 return nullptr; 119 } 120 auto status = tensorflow::make_safe(TF_NewStatus()); 121 TFE_TensorHandle* new_handle = 122 TFE_TensorHandleCopyToDevice(handle, context, device, status.get()); 123 if (TF_GetCode(status.get()) != TF_OK) { 124 PyErr_SetString( 125 PyExc_RuntimeError, 126 tensorflow::strings::StrCat("Error copying tensor to device: ", device, 127 ". ", TF_Message(status.get())) 128 .c_str()); 129 return nullptr; 130 } 131 return new_handle; 132 } 133 134 // Helper function to convert `v` to an int and store it in `*out`. Returns true 135 // on success, false otherwise. 136 // Note that we assume that v is a python int (not long) representing a 137 // TF_DataType value. 138 bool PyIntToDataType(PyObject* v, int* out) { 139 #if PY_MAJOR_VERSION < 3 140 if (PyInt_Check(v)) { 141 *out = PyInt_AS_LONG(v); 142 return true; 143 } 144 #else 145 if (PyLong_Check(v)) { 146 *out = PyLong_AsLong(v); 147 return true; 148 } 149 #endif 150 return false; 151 } 152 153 // Helper function to create a python integer from TF_DataType. 154 PyObject* PyIntFromDataType(TF_DataType l) { 155 #if PY_MAJOR_VERSION < 3 156 return PyInt_FromLong(l); 157 #else 158 return PyLong_FromLong(l); 159 #endif 160 } 161 162 } // namespace 163 164 extern "C" { 165 166 static const int kMaxEagerTensorParentSize = 32; 167 168 // TODO(agarwal): store context handle in EagerTensor. 169 typedef struct EagerTensor { 170 PyObject_HEAD; 171 // Note that we leave kMaxEagerTensorParentSize bytes here for use by the 172 // parent class. The parent class is set at runtime, so we don't know the 173 // exact size at compile time. 174 char unused[kMaxEagerTensorParentSize]; 175 TFE_TensorHandle* handle; 176 int64_t id; 177 // This mirrors tensorflow.core.framework.ops.Tensor._handle_data Which will 178 // be None for tensors of type other than DT_REOSURCE. For DT_RESOURCE 179 // tensors, this will contain a serialized HandleData proto with shape 180 // inference metadata about shapes and dtypes of resources accessible from 181 // this handle. 182 // Note that we assume that handle_data cannot participate in reference 183 // cycles, and hence don't provide GC support for it. 184 PyObject* handle_data; 185 186 // This stores `_keras_mask` object and is set by Tensorflow layers. 187 PyObject* keras_mask; 188 } EagerTensor; 189 190 // tp_init for EagerTensor. 191 int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) { 192 self->id = get_uid(); 193 self->handle = nullptr; 194 Py_INCREF(Py_None); 195 self->handle_data = Py_None; 196 Py_INCREF(Py_None); 197 self->keras_mask = Py_None; 198 PyObject* value; 199 PyObject* context = nullptr; 200 PyObject* device = nullptr; 201 PyObject* dtype = Py_None; 202 const char* kwlist[] = {"value", "context", "device", "dtype", nullptr}; 203 if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|O", 204 const_cast<char**>(kwlist), &value, &context, 205 &device, &dtype)) { 206 return -1; 207 } 208 // Extract dtype 209 int desired_dtype = -1; 210 if (dtype != Py_None) { 211 if (!PyIntToDataType(dtype, &desired_dtype)) { 212 PyErr_SetString(PyExc_TypeError, 213 tensorflow::strings::StrCat( 214 "Expecting a DataType value for dtype. Got ", 215 Py_TYPE(dtype)->tp_name) 216 .c_str()); 217 return -1; 218 } 219 } 220 tensorflow::Safe_TFE_TensorHandlePtr handle = 221 tensorflow::make_safe(static_cast<TFE_TensorHandle*>(nullptr)); 222 PyErr_Clear(); 223 if (PyArray_Check(value)) { 224 int desired_np_dtype = -1; 225 if (desired_dtype >= 0) { 226 if (!tensorflow::TF_DataType_to_PyArray_TYPE( 227 static_cast<TF_DataType>(desired_dtype), &desired_np_dtype) 228 .ok()) { 229 PyErr_SetString(PyExc_TypeError, 230 tensorflow::strings::StrCat( 231 "Invalid dtype argument value ", desired_dtype) 232 .c_str()); 233 return -1; 234 } 235 } 236 PyArrayObject* array = reinterpret_cast<PyArrayObject*>(value); 237 int current_np_dtype = PyArray_TYPE(array); 238 auto safe_value = tensorflow::make_safe(static_cast<PyObject*>(nullptr)); 239 if ((desired_np_dtype >= 0 && desired_np_dtype != current_np_dtype) || 240 !PyArray_ISCARRAY(array)) { 241 int new_dtype = 242 desired_np_dtype >= 0 ? desired_np_dtype : current_np_dtype; 243 safe_value = tensorflow::make_safe( 244 PyArray_FromAny(value, PyArray_DescrFromType(new_dtype), 0, 0, 245 NPY_ARRAY_CARRAY | NPY_ARRAY_FORCECAST, nullptr)); 246 if (PyErr_Occurred()) return -1; 247 if (safe_value == nullptr) { 248 PyErr_SetString(PyExc_ValueError, "Error while casting a numpy value"); 249 return -1; 250 } 251 value = safe_value.get(); 252 } 253 handle = tensorflow::make_safe(NumpyToTensorHandle(value)); 254 } else { 255 tensorflow::Tensor t; 256 // TODO(josh11b): Have PySeqToTensor set python errors instead of 257 // returning Status. 258 auto cppstatus = tensorflow::PySeqToTensor(value, dtype, &t); 259 if (!cppstatus.ok()) { 260 PyErr_SetString(PyExc_ValueError, cppstatus.error_message().c_str()); 261 return -1; 262 } 263 handle = tensorflow::make_safe(TFE_NewTensorHandle(t)); 264 } 265 if (PyErr_Occurred()) return -1; 266 if (handle == nullptr) { 267 PyErr_SetString(PyExc_ValueError, "Error while creating an EagerTensor"); 268 return -1; 269 } 270 TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get()); 271 if (desired_dtype >= 0 && desired_dtype != handle_dtype) { 272 auto out_status = tensorflow::make_safe(TF_NewStatus()); 273 handle = tensorflow::make_safe( 274 EagerCast(GetContext(context), handle.get(), handle_dtype, 275 static_cast<TF_DataType>(desired_dtype), out_status.get())); 276 if (TF_GetCode(out_status.get()) != TF_OK) { 277 PyErr_SetString( 278 PyExc_ValueError, 279 tensorflow::strings::StrCat("Error while casting from DataType ", 280 handle_dtype, " to ", desired_dtype, ". ", 281 TF_Message(out_status.get())) 282 .c_str()); 283 return -1; 284 } 285 handle_dtype = TFE_TensorHandleDataType(handle.get()); 286 } 287 288 // Almost all TensorFlow kernels for GPU devices keep int32 tensors in host 289 // memory. We approximate the same behavior for eager execution - keeping 290 // int32 tensors in host memory. 291 // 292 // We do so to preclude the need for callers into such kernels from having to 293 // explicitly place the int32 tensors in host memory. For example, without 294 // this, one needed: 295 // 296 // with tf.device('/gpu:0'): 297 // ...// code here 298 // with tf.device('/cpu:0'): 299 // shape = tf.constant(...) 300 // y = tf.random_uniform(shape) 301 // 302 // Without the CPU device block, tfe.ops.random_uniform would fail since the 303 // kernel expects the shape in host memory. 304 // 305 // With this support, we simplify the code: 306 // 307 // with tf.device('/gpu:0'): 308 // y = tf.random_uniform(...) 309 // 310 // The approximation is not exact there are GPU kernels which do not require 311 // host memory for int32 tensors. This will lead to a discrepancy between 312 // eager and graph execution. 313 // TODO(ashankar): Fix this. 314 if (handle_dtype != TF_INT32) { 315 // Note that this is a shallow copy and will share the underlying buffer 316 // if copying to the same device. 317 handle = tensorflow::make_safe(CopyToDevice(handle.get(), context, device)); 318 if (handle == nullptr) return -1; 319 } 320 self->handle = handle.release(); 321 return 0; 322 } 323 324 // tp_dealloc for EagerTensor. 325 void EagerTensor_dealloc(EagerTensor* self) { 326 Py_DECREF(self->handle_data); 327 Py_DECREF(self->keras_mask); 328 TFE_DeleteTensorHandle(self->handle); 329 self->handle = nullptr; 330 // We have the global interpreter lock, so use this chance to perform delayed 331 // refcount decrements. 332 tensorflow::ClearDecrefCache(); 333 auto id = self->id; 334 Py_TYPE(self)->tp_free(self); 335 TFE_Py_TapeSetDeleteTrace(id); 336 } 337 338 // Getter for `_id`. 339 static PyObject* EagerTensor_getid(EagerTensor* self, void* closure) { 340 return PyLong_FromLongLong(self->id); 341 } 342 343 // Getter for `_datatype_enum`. 344 static PyObject* EagerTensor_datatype_enum(EagerTensor* self) { 345 return PyIntFromDataType(TFE_TensorHandleDataType(self->handle)); 346 } 347 348 // Getter for `_shape_tuple`. 349 static PyObject* EagerTensor_shape_tuple(EagerTensor* self) { 350 auto handle = self->handle; 351 int n = TFE_TensorHandleNumDims(handle); 352 PyObject* shape = PyTuple_New(n); 353 if (PyErr_Occurred()) return nullptr; 354 for (int i = 0; i < n; ++i) { 355 PyObject* dim = PyLong_FromLongLong(TFE_TensorHandleDim(handle, i)); 356 if (dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) { 357 Py_DECREF(shape); 358 if (dim != nullptr) Py_DECREF(dim); 359 PyErr_SetString(PyExc_RuntimeError, "Error while creating shape"); 360 return nullptr; 361 } 362 } 363 return shape; 364 } 365 366 // Getter for `_rank`. 367 static PyObject* EagerTensor_rank(EagerTensor* self) { 368 #if PY_MAJOR_VERSION < 3 369 return PyInt_FromLong(TFE_TensorHandleNumDims(self->handle)); 370 #else 371 return PyLong_FromLong(TFE_TensorHandleNumDims(self->handle)); 372 #endif 373 } 374 375 static PyObject* EagerTensor_tensor_handle(EagerTensor* self, void* unused) { 376 Py_INCREF(self->handle_data); 377 return self->handle_data; 378 } 379 380 static int EagerTensor_settensor_handle(EagerTensor* self, PyObject* value, 381 void* unused) { 382 Py_DECREF(self->handle_data); 383 Py_INCREF(value); 384 self->handle_data = value; 385 return 0; 386 } 387 388 static PyObject* EagerTensor_keras_mask(EagerTensor* self, void* unused) { 389 Py_INCREF(self->keras_mask); 390 return self->keras_mask; 391 } 392 393 static int EagerTensor_setkeras_mask(EagerTensor* self, PyObject* value, 394 void* unused) { 395 Py_DECREF(self->keras_mask); 396 Py_INCREF(value); 397 self->keras_mask = value; 398 return 0; 399 } 400 // Function `_copy_to_device`. 401 static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args, 402 PyObject* kwds) { 403 const char* kwlist[] = {"context", "device", nullptr}; 404 PyObject* ctx = nullptr; 405 PyObject* dev = nullptr; 406 if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO", const_cast<char**>(kwlist), 407 &ctx, &dev) || 408 !ctx || !dev) { 409 return nullptr; 410 } 411 auto handle = CopyToDevice(self->handle, ctx, dev); 412 return EagerTensorFromHandle(handle); 413 } 414 415 // Function `_numpy`. 416 // Convert an EagerTensor to a Python numpy.ndarray object. 417 // The two may share underlying storage so changes to one may reflect in the 418 // other. 419 // Note that if `self` is not on CPU, we raise an Exception. 420 static PyObject* EagerTensor_numpy(EagerTensor* self) { 421 auto status = tensorflow::make_safe(TF_NewStatus()); 422 const tensorflow::Tensor* t = 423 TFE_TensorHandleUnderlyingTensorInHostMemory(self->handle, status.get()); 424 if (TF_GetCode(status.get()) != TF_OK) { 425 PyErr_SetString(PyExc_RuntimeError, TF_Message(status.get())); 426 return nullptr; 427 } 428 PyObject* ret = nullptr; 429 auto cppstatus = tensorflow::TensorToNdarray(*t, &ret); 430 if (MaybeRaiseExceptionFromStatus(cppstatus, PyExc_RuntimeError)) { 431 Py_XDECREF(ret); 432 return nullptr; 433 } else { 434 return ret; 435 } 436 } 437 438 // Getter `device`. 439 static PyObject* EagerTensor_device(EagerTensor* self) { 440 #if PY_MAJOR_VERSION >= 3 441 return PyUnicode_FromString(TFE_TensorHandleDeviceName(self->handle)); 442 #else 443 return PyBytes_FromString(TFE_TensorHandleDeviceName(self->handle)); 444 #endif 445 } 446 447 static PyGetSetDef EagerTensor_getseters[] = { 448 {const_cast<char*>("_id"), (getter)EagerTensor_getid, nullptr, 449 const_cast<char*>("_id"), nullptr}, 450 {const_cast<char*>("device"), (getter)EagerTensor_device, nullptr, 451 const_cast<char*>("device"), nullptr}, 452 {const_cast<char*>("_handle_data"), (getter)EagerTensor_tensor_handle, 453 (setter)EagerTensor_settensor_handle, const_cast<char*>("_tensor_handle"), 454 nullptr}, 455 {const_cast<char*>("_keras_mask"), (getter)EagerTensor_keras_mask, 456 (setter)EagerTensor_setkeras_mask, const_cast<char*>("_keras_mask"), 457 nullptr}, 458 {nullptr} /* Sentinel */ 459 }; 460 461 static PyMethodDef EagerTensor_methods[] = { 462 {"_numpy", (PyCFunction)EagerTensor_numpy, METH_NOARGS, 463 PyDoc_STR("_numpy")}, 464 {"_datatype_enum", (PyCFunction)EagerTensor_datatype_enum, METH_NOARGS, 465 PyDoc_STR("_datatype_enum")}, 466 {"_shape_tuple", (PyCFunction)EagerTensor_shape_tuple, METH_NOARGS, 467 PyDoc_STR("_shape_tuple")}, 468 {"_rank", (PyCFunction)EagerTensor_rank, METH_NOARGS, PyDoc_STR("_rank")}, 469 {"_copy_to_device", (PyCFunction)EagerTensor_copy_to_device, 470 METH_VARARGS | METH_KEYWORDS, PyDoc_STR("_copy_to_device")}, 471 {nullptr, nullptr}, 472 }; 473 474 // Note that here we are trying to dynamically create a new class as a subclass 475 // of a "HEAPTYPE" class that is itself created in python code and passed in at 476 // runtime. This is fairly atypical and undocumented. 477 // 478 // We use the following strategy for this. Unfortunately, we have to use 479 // different approaches for python2.x vs python3.x 480 // For python2.x, we create the class as a static type and set its tp_base to 481 // the passed in type. Unfortunately setting tp_flags to include 482 // Py_TPFLAGS_HEAPTYPE does not work by itself since it needs some more 483 // initialization of the underlying PyHeapTypeObject and not doing that leads to 484 // some random crashes especially during garbage collection. 485 // python3.x explicitly disables a static subclass of a HEAPTYPE base class. 486 // However it provides a new function, PyType_FromSpecWithBases, to create 487 // types dynamically. 488 489 // Type object for EagerTensor. This is set by TFE_Py_InitEagerTensor. 490 PyTypeObject* EagerTensorType = nullptr; 491 492 #if PY_MAJOR_VERSION >= 3 493 static PyType_Slot EagerTensor_Type_slots[] = { 494 Py_tp_dealloc, 495 reinterpret_cast<void*>(EagerTensor_dealloc), 496 Py_tp_methods, 497 reinterpret_cast<void*>(EagerTensor_methods), 498 Py_tp_getset, 499 reinterpret_cast<void*>(EagerTensor_getseters), 500 Py_tp_init, 501 reinterpret_cast<void*>(EagerTensor_init), 502 0, 503 nullptr, 504 }; 505 506 PyType_Spec EagerTensor_Type_spec = {"EagerTensor", sizeof(EagerTensor), 0, 507 Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE, 508 EagerTensor_Type_slots}; 509 #else 510 // TODO(agarwal): support active_trace. 511 static PyTypeObject _EagerTensorType = { 512 // clang-format off 513 PyVarObject_HEAD_INIT(nullptr, 0) 514 // clang-format on 515 "EagerTensor", /* tp_name */ 516 sizeof(EagerTensor), /* tp_basicsize */ 517 0, /* tp_itemsize */ 518 (destructor)EagerTensor_dealloc, /* tp_dealloc */ 519 nullptr, /* tp_print */ 520 nullptr, /* tp_getattr */ 521 nullptr, /* tp_setattr */ 522 nullptr, /* tp_compare */ 523 nullptr, /* tp_repr */ 524 nullptr, /* tp_as_number */ 525 nullptr, /* tp_as_sequence */ 526 nullptr, /* tp_as_mapping */ 527 nullptr, /* tp_hash */ 528 nullptr, /* tp_call */ 529 nullptr, /* tp_str */ 530 nullptr, /* tp_getattro */ 531 nullptr, /* tp_setattro */ 532 nullptr, /* tp_as_buffer */ 533 Py_TPFLAGS_DEFAULT, /* tp_flags */ 534 nullptr, /* tp_doc */ 535 nullptr, /* tp_traverse */ 536 nullptr, /* tp_clear */ 537 nullptr, /* tp_richcompare */ 538 0, /* tp_weaklistoffset */ 539 nullptr, /* tp_iter */ 540 nullptr, /* tp_iternext */ 541 EagerTensor_methods, /* tp_methods */ 542 nullptr, /* tp_members */ 543 EagerTensor_getseters, /* tp_getset */ 544 nullptr, /* tp_base */ 545 nullptr, /* tp_dict */ 546 nullptr, /* tp_descr_get */ 547 nullptr, /* tp_descr_set */ 548 0, /* tp_dictoffset */ 549 (initproc)EagerTensor_init, /* tp_init */ 550 nullptr, /* tp_alloc */ 551 nullptr, /* tp_new */ 552 }; 553 554 #endif 555 556 } // extern "C" 557 558 bool EagerTensor_CheckExact(const PyObject* o) { 559 return Py_TYPE(o) == EagerTensorType; 560 } 561 562 TFE_TensorHandle* EagerTensor_Handle(const PyObject* o) { 563 return reinterpret_cast<const EagerTensor*>(o)->handle; 564 } 565 566 PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) { 567 if (handle == nullptr) { 568 return nullptr; 569 } 570 EagerTensor* t = reinterpret_cast<EagerTensor*>( 571 EagerTensorType->tp_new(EagerTensorType, Py_None, Py_None)); 572 if (t != nullptr) { 573 t->id = get_uid(); 574 Py_INCREF(Py_None); 575 t->handle_data = Py_None; 576 Py_INCREF(Py_None); 577 t->keras_mask = Py_None; 578 t->handle = handle; 579 } 580 return reinterpret_cast<PyObject*>(t); 581 } 582 583 tensorflow::int64 EagerTensor_id(const PyObject* tensor) { 584 CHECK(EagerTensor_CheckExact(tensor)); 585 return reinterpret_cast<const EagerTensor*>(tensor)->id; 586 } 587 588 PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) { 589 if (!PyType_Check(base_class)) { 590 PyErr_SetString( 591 PyExc_TypeError, 592 tensorflow::strings::StrCat( 593 "Expecting a class definition for `base_class` passed to ", 594 "TFE_InitEagerTensor. Got ", Py_TYPE(base_class)->tp_name) 595 .c_str()); 596 return nullptr; 597 } 598 // Note that we allocated kMaxEagerTensorParentSize bytes of unused space in 599 // EagerTensor to allow for the space usage of the base class. 600 PyTypeObject* base_class_type = reinterpret_cast<PyTypeObject*>(base_class); 601 if (base_class_type->tp_basicsize > kMaxEagerTensorParentSize) { 602 PyErr_SetString( 603 PyExc_TypeError, 604 tensorflow::strings::StrCat( 605 "Unable to create subclass EagerTensor from base class ", 606 Py_TYPE(base_class)->tp_name, 607 ". Need its size to be <= ", kMaxEagerTensorParentSize) 608 .c_str()); 609 return nullptr; 610 } 611 if (base_class_type->tp_itemsize != 0) { 612 PyErr_SetString( 613 PyExc_TypeError, 614 tensorflow::strings::StrCat( 615 "Unable to create subclass EagerTensor from base class ", 616 Py_TYPE(base_class)->tp_name, 617 " which supports variable length instances.") 618 .c_str()); 619 return nullptr; 620 } 621 Py_INCREF(base_class); 622 #if PY_MAJOR_VERSION >= 3 623 PyObject* bases = PyTuple_New(1); 624 PyTuple_SET_ITEM(bases, 0, base_class); 625 EagerTensorType = reinterpret_cast<PyTypeObject*>( 626 PyType_FromSpecWithBases(&EagerTensor_Type_spec, bases)); 627 if (PyErr_Occurred()) { 628 return nullptr; 629 } 630 if (EagerTensorType == nullptr) { 631 PyErr_SetString(PyExc_RuntimeError, "Error while creating EagerTensorType"); 632 return nullptr; 633 } 634 #else 635 _EagerTensorType.tp_base = reinterpret_cast<PyTypeObject*>(base_class); 636 637 if (PyType_Ready(&_EagerTensorType) < 0) { 638 if (PyErr_Occurred()) return nullptr; 639 PyErr_SetString(PyExc_RuntimeError, 640 "Error while creating EagerTensor type."); 641 return nullptr; 642 } 643 EagerTensorType = &_EagerTensorType; 644 Py_INCREF(EagerTensorType); 645 #endif 646 // We disable instance based attribute lookup. Its not clear if these 647 // dictionaries are correctly initialized in the first place. 648 EagerTensorType->tp_dictoffset = 0; 649 return reinterpret_cast<PyObject*>(EagerTensorType); 650 } 651 652 PyObject* TFE_Py_TensorShapeSlice(PyObject* tensor_list, int slice_dim) { 653 if (!PyList_Check(tensor_list)) { 654 PyErr_SetString(PyExc_TypeError, 655 tensorflow::strings::StrCat( 656 "tensor_list argument must be a list. Got \"", 657 Py_TYPE(tensor_list)->tp_name, "\"") 658 .c_str()); 659 return nullptr; 660 } 661 if (slice_dim < 0) { 662 PyErr_SetString( 663 PyExc_ValueError, 664 tensorflow::strings::StrCat("Slice dimension must be non-negative. " 665 "Got ", 666 slice_dim) 667 .c_str()); 668 return nullptr; 669 } 670 671 Py_ssize_t num_tensors = PyList_Size(tensor_list); 672 int64_t num_tensors_int = static_cast<int64_t>(num_tensors); 673 auto tensor = tensorflow::make_safe(TF_AllocateTensor( 674 TF_INT32, &num_tensors_int, /*num_dims=*/1, /*len=*/4 * num_tensors_int)); 675 int32_t* data = reinterpret_cast<int32_t*>(TF_TensorData(tensor.get())); 676 for (Py_ssize_t i = 0; i < num_tensors; ++i) { 677 PyObject* tensor_obj = PyList_GET_ITEM(tensor_list, i); 678 if (!EagerTensor_CheckExact(tensor_obj)) { 679 PyErr_SetString(PyExc_TypeError, 680 tensorflow::strings::StrCat( 681 "Expected a list of EagerTensors but " 682 "element ", 683 i, " has type \"", Py_TYPE(tensor_obj)->tp_name, "\"") 684 .c_str()); 685 return nullptr; 686 } 687 688 EagerTensor* t = reinterpret_cast<EagerTensor*>(tensor_obj); 689 TFE_TensorHandle* handle = t->handle; 690 if (slice_dim >= TFE_TensorHandleNumDims(handle)) { 691 PyErr_SetString(PyExc_IndexError, 692 tensorflow::strings::StrCat( 693 "Slice dimension (", slice_dim, 694 ") must be smaller than rank of all " 695 "tensors, but tensor at index ", 696 i, " has rank ", TFE_TensorHandleNumDims(handle)) 697 .c_str()); 698 return nullptr; 699 } 700 int64_t dim = TFE_TensorHandleDim(handle, slice_dim); 701 data[i] = dim; 702 } 703 704 auto status = tensorflow::make_safe(TF_NewStatus()); 705 TFE_TensorHandle* handle = TFE_NewTensorHandle(tensor.get(), status.get()); 706 if (TF_GetCode(status.get()) != TF_OK) { 707 PyErr_SetString( 708 PyExc_RuntimeError, 709 tensorflow::strings::StrCat("Failed to construct new tensor handle: ", 710 TF_Message(status.get())) 711 .c_str()); 712 return nullptr; 713 } 714 715 return EagerTensorFromHandle(handle); 716 } 717