Home | History | Annotate | Download | only in core
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/python/lib/core/py_seq_tensor.h"
     17 
     18 #include "tensorflow/core/framework/tensor.h"
     19 #include "tensorflow/core/framework/tensor_shape.h"
     20 #include "tensorflow/core/framework/types.h"
     21 #include "tensorflow/core/lib/core/errors.h"
     22 #include "tensorflow/core/lib/core/stringpiece.h"
     23 #include "tensorflow/core/platform/types.h"
     24 #include "tensorflow/python/lib/core/numpy.h"
     25 #include "tensorflow/python/lib/core/py_util.h"
     26 #include "tensorflow/python/lib/core/safe_ptr.h"
     27 
     28 namespace tensorflow {
     29 namespace {
     30 
     31 inline bool PyIsInstance(PyObject* obj, PyTypeObject* t) {
     32   return PyObject_IsInstance(obj, reinterpret_cast<PyObject*>(t));
     33 }
     34 
     35 inline PyObject* PyType(PyObject* obj) {
     36   return reinterpret_cast<PyObject*>(obj->ob_type);
     37 }
     38 
     39 bool IsPyString(PyObject* obj) {
     40   return PyBytes_Check(obj) || PyUnicode_Check(obj);
     41 }
     42 
     43 bool IsPyInt(PyObject* obj) {
     44 #if PY_MAJOR_VERSION >= 3
     45   return PyLong_Check(obj) ||
     46          PyIsInstance(obj, &PyIntegerArrType_Type);  // NumPy integers
     47 #else
     48   return PyInt_Check(obj) || PyLong_Check(obj) ||
     49          PyIsInstance(obj, &PyIntegerArrType_Type);  // NumPy integers
     50 #endif
     51 }
     52 
     53 bool IsPyFloat(PyObject* obj) {
     54   return PyFloat_Check(obj) ||
     55          PyIsInstance(obj, &PyFloatingArrType_Type);  // NumPy float types
     56 }
     57 
// Converts the Python object `v`, which should hold a Python string (bytes
// or unicode), into a C++ string in *out.  Returns nullptr on success, or a
// pointer to a static error message on failure.
// Defined below, but forward declared here for use in PyRepr.
const char* ConvertOneString(PyObject* v, string* out);
     62 
     63 string PyRepr(PyObject* obj) {
     64   if (obj == nullptr) {
     65     return "<null>";
     66   }
     67   Safe_PyObjectPtr repr_obj = make_safe(PyObject_Repr(obj));
     68   if (repr_obj) {
     69     string repr_str;
     70     if (ConvertOneString(repr_obj.get(), &repr_str) == nullptr) {
     71       return repr_str;
     72     }
     73   }
     74   return "<error computing repr()>";
     75 }
     76 
     77 bool IsPyDimension(PyObject* obj) {
     78   const char* tp_name = obj->ob_type->tp_name;
     79   if (strcmp(tp_name, "Dimension") != 0) return false;
     80   bool ret =
     81       StringPiece(PyRepr(PyType(obj)))
     82           .ends_with("tensorflow.python.framework.tensor_shape.Dimension'>");
     83   return ret;
     84 }
     85 
     86 Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) {
     87   while (true) {
     88     // We test strings first, in case a string is considered a sequence.
     89     if (IsPyString(obj)) {
     90       *dtype = DT_STRING;
     91     } else if (PySequence_Check(obj)) {
     92       auto length = PySequence_Length(obj);
     93       if (length > 0) {
     94         shape->AddDim(length);
     95         obj = PySequence_GetItem(obj, 0);
     96         continue;
     97       } else if (length == 0) {
     98         shape->AddDim(length);
     99         *dtype = DT_INVALID;  // Invalid dtype for empty tensors.
    100       } else {
    101         // The sequence does not have a valid length (PySequence_Length < 0).
    102         if (PyErr_Occurred()) {
    103           // PySequence_Length failed and set an exception. Fetch the message
    104           // and convert it to a failed status.
    105           return errors::InvalidArgument(PyExceptionFetch());
    106         } else {
    107           // This is almost certainly dead code: PySequence_Length failed but
    108           // did not set an exception.
    109           return errors::InvalidArgument(
    110               "Attempted to convert an invalid sequence to a Tensor.");
    111         }
    112       }
    113     } else if (IsPyFloat(obj)) {
    114       *dtype = DT_DOUBLE;
    115     } else if (PyBool_Check(obj) || PyIsInstance(obj, &PyBoolArrType_Type)) {
    116       // Have to test for bool before int, since IsInt(True/False) == true.
    117       *dtype = DT_BOOL;
    118     } else if (IsPyInt(obj)) {
    119       *dtype = DT_INT64;
    120     } else if (IsPyDimension(obj)) {
    121       *dtype = DT_INT64;
    122     } else if (PyComplex_Check(obj) ||
    123                PyIsInstance(obj, &PyComplexFloatingArrType_Type)) {  // NumPy
    124       *dtype = DT_COMPLEX128;
    125     } else {
    126       return errors::InvalidArgument("Attempt to convert a value (",
    127                                      PyRepr(obj),
    128                                      ") with an unsupported type (",
    129                                      PyRepr(PyType(obj)), ") to a Tensor.");
    130     }
    131     return Status::OK();
    132   }
    133 }
    134 
// Error messages
//
// The conversion helpers in this file return nullptr on success or a pointer
// to one of these static messages on failure.  PySeqToTensor compares the
// returned pointer against ErrorFoundInt64 and ErrorFoundFloat (by address,
// not by content) to decide whether to retry with a wider type, so the
// helpers must return these exact arrays rather than copies of the text.

// A null object was handed to a conversion helper.
const char ErrorConverting[] =
    "Error while converting Python sequence to Tensor.";
// A (sub)sequence's length does not match the inferred shape, or a
// (sub)sequence could not be read as a sequence.
const char ErrorRectangular[] =
    "Can't convert non-rectangular Python sequence to Tensor.";
// An element is not of a type the selected converter accepts.
const char ErrorMixedTypes[] =
    "Can't convert Python sequence with mixed types to Tensor.";
// An integer does not fit in 64 bits.
const char ErrorOutOfRange[] =
    "Can't convert Python sequence with out-of-range integer to Tensor.";
// A Python integer could not be converted to a double.
const char ErrorOutOfRangeDouble[] =
    "Can't convert Python sequence with a value out of range for a "
    "double-precision float.";
// A unicode object could not be encoded as UTF-8.
const char ErrorConvertingUnicodeString[] =
    "Error converting unicode string while converting Python sequence to "
    "Tensor.";
// An integer fits in 64 bits but not in 32 bits.
const char ErrorFoundInt64[] =
    "Can't convert Python sequence with out-of-range integer to int32 Tensor.";
// A floating point value was found while converting to an integer type.
const char ErrorFoundFloat[] =
    "Can't convert Python sequence with floating point values to integer "
    "Tensor.";
    156 
    157 // Template for defining a function for recursively convering obj into
    158 // an array of TYPE using the conversion function CONVERT.
    159 // Note that these helper functions require shape.dims() >= 1.
    160 
    161 #define DEFINE_HELPER(FUNCTION, TYPE, TYPE_ENUM, CONVERT)                 \
    162   const char* FUNCTION##Helper(PyObject* obj, const TensorShape& shape,   \
    163                                TYPE** buf) {                              \
    164     if (TF_PREDICT_FALSE(obj == nullptr)) {                               \
    165       return ErrorConverting;                                             \
    166     }                                                                     \
    167     if (shape.dims() > 1) {                                               \
    168       /* Iterate over outer dim, and recursively convert each element. */ \
    169       const int64 s = shape.dim_size(0);                                  \
    170       if (TF_PREDICT_FALSE(s != PySequence_Length(obj))) {                \
    171         return ErrorRectangular;                                          \
    172       }                                                                   \
    173       TensorShape rest = shape;                                           \
    174       rest.RemoveDim(0);                                                  \
    175       for (int64 i = 0; i < s; ++i) {                                     \
    176         const char* error =                                               \
    177             FUNCTION##Helper(PySequence_GetItem(obj, i), rest, buf);      \
    178         if (TF_PREDICT_FALSE(error != nullptr)) return error;             \
    179       }                                                                   \
    180     } else {                                                              \
    181       Safe_PyObjectPtr seq = make_safe(PySequence_Fast(obj, ""));         \
    182       if (TF_PREDICT_FALSE(seq == nullptr)) return ErrorRectangular;      \
    183       const int64 s = shape.dim_size(0);                                  \
    184       if (TF_PREDICT_FALSE(s != PySequence_Fast_GET_SIZE(seq.get()))) {   \
    185         return ErrorRectangular;                                          \
    186       }                                                                   \
    187       PyObject** l = PySequence_Fast_ITEMS(seq.get());                    \
    188       for (int64 i = 0; i < s; ++i) {                                     \
    189         const char* error = CONVERT(l[i], *buf);                          \
    190         if (TF_PREDICT_FALSE(error != nullptr)) return error;             \
    191         ++*buf;                                                           \
    192       }                                                                   \
    193     }                                                                     \
    194     return nullptr;                                                       \
    195   }                                                                       \
    196   const char* FUNCTION(PyObject* obj, const TensorShape& shape,           \
    197                        Tensor* dest) {                                    \
    198     /* TODO(josh11b): Allocator & attributes? */                            \
    199     Tensor result(TYPE_ENUM, shape);                                      \
    200     if (shape.dims() == 0) { /* Scalar case */                            \
    201       TYPE value;                                                         \
    202       const char* error = CONVERT(obj, &value);                           \
    203       if (error != nullptr) return error;                                 \
    204       result.scalar<TYPE>()() = value;                                    \
    205     } else {                                                              \
    206       TYPE* buf = result.flat<TYPE>().data();                             \
    207       const char* error = FUNCTION##Helper(obj, shape, &buf);             \
    208       if (error != nullptr) return error;                                 \
    209     }                                                                     \
    210     *dest = result;                                                       \
    211     return nullptr;                                                       \
    212   }
    213 
    214 // Int support
    215 
    216 const char* ConvertOneInt64(PyObject* v, int64* out) {
    217 #if PY_MAJOR_VERSION < 3
    218   if (TF_PREDICT_TRUE(PyInt_Check(v))) {
    219     *out = PyInt_AS_LONG(v);
    220     return nullptr;
    221   }
    222 #endif
    223   if (TF_PREDICT_TRUE(PyLong_Check(v) || IsPyDimension(v))) {
    224     int overflow = 0;
    225     // Have to use LongLong for 64 bits, since long is 32 bits on Windows.
    226     *out = PyLong_AsLongLongAndOverflow(v, &overflow);
    227     if (TF_PREDICT_FALSE(overflow)) return ErrorOutOfRange;
    228     return nullptr;
    229   }
    230   if (PyIsInstance(v, &PyIntegerArrType_Type)) {  // NumPy integers
    231 #if PY_MAJOR_VERSION < 3
    232     Safe_PyObjectPtr as_int = make_safe(PyNumber_Int(v));
    233 #else
    234     Safe_PyObjectPtr as_int = make_safe(PyNumber_Long(v));
    235 #endif
    236     return ConvertOneInt64(as_int.get(), out);
    237   }
    238   if (IsPyFloat(v)) return ErrorFoundFloat;
    239   return ErrorMixedTypes;
    240 }
    241 
    242 DEFINE_HELPER(ConvertInt64, int64, DT_INT64, ConvertOneInt64);
    243 
    244 const char* ConvertOneInt32(PyObject* v, int32* out) {
    245   int64 i;
    246 #if PY_MAJOR_VERSION < 3
    247   if (TF_PREDICT_TRUE(PyInt_Check(v))) {
    248     i = PyInt_AS_LONG(v);
    249   } else
    250 #endif
    251       if (PyLong_Check(v) || IsPyDimension(v)) {
    252     int overflow = 0;
    253     // Have to use LongLong for 64 bits, since long is 32 bits on Windows.
    254     i = PyLong_AsLongLongAndOverflow(v, &overflow);
    255     if (TF_PREDICT_FALSE(overflow)) return ErrorOutOfRange;
    256   } else if (PyIsInstance(v, &PyIntegerArrType_Type)) {  // NumPy integers
    257 #if PY_MAJOR_VERSION < 3
    258     Safe_PyObjectPtr as_int = make_safe(PyNumber_Int(v));
    259 #else
    260     Safe_PyObjectPtr as_int = make_safe(PyNumber_Long(v));
    261 #endif
    262     return ConvertOneInt32(as_int.get(), out);
    263   } else if (IsPyFloat(v)) {
    264     return ErrorFoundFloat;
    265   } else {
    266     return ErrorMixedTypes;
    267   }
    268   *out = static_cast<uint32>(static_cast<uint64>(i));
    269   // Check for 32-bit overflow.
    270   if (TF_PREDICT_FALSE(i != *out)) return ErrorFoundInt64;
    271   return nullptr;
    272 }
    273 
    274 DEFINE_HELPER(ConvertInt32, int32, DT_INT32, ConvertOneInt32);
    275 
    276 // Floating-point support
    277 
    278 template <class T>
    279 const char* ConvertOneFloat(PyObject* v, T* out) {
    280   if (TF_PREDICT_TRUE(PyFloat_Check(v))) {
    281     *out = PyFloat_AS_DOUBLE(v);
    282     return nullptr;
    283   }
    284 #if PY_MAJOR_VERSION < 3
    285   if (PyInt_Check(v)) {
    286     *out = PyInt_AS_LONG(v);
    287     return nullptr;
    288   }
    289 #endif
    290   if (PyLong_Check(v)) {
    291     *out = PyLong_AsDouble(v);
    292     if (PyErr_Occurred()) return ErrorOutOfRangeDouble;
    293     return nullptr;
    294   }
    295   if (PyIsInstance(v, &PyFloatingArrType_Type)) {  // NumPy float types
    296     Safe_PyObjectPtr as_float = make_safe(PyNumber_Float(v));
    297     return ConvertOneFloat<T>(as_float.get(), out);
    298   }
    299   if (PyIsInstance(v, &PyIntegerArrType_Type)) {  // NumPy integers
    300 #if PY_MAJOR_VERSION < 3
    301     Safe_PyObjectPtr as_int = make_safe(PyNumber_Int(v));
    302 #else
    303     Safe_PyObjectPtr as_int = make_safe(PyNumber_Long(v));
    304 #endif
    305     return ConvertOneFloat<T>(as_int.get(), out);
    306   }
    307   return ErrorMixedTypes;
    308 }
    309 
    310 DEFINE_HELPER(ConvertDouble, double, DT_DOUBLE, ConvertOneFloat<double>);
    311 DEFINE_HELPER(ConvertFloat, float, DT_FLOAT, ConvertOneFloat<float>);
    312 
    313 // String support
    314 
    315 const char* ConvertOneString(PyObject* v, string* out) {
    316   if (PyBytes_Check(v)) {
    317     out->assign(PyBytes_AS_STRING(v), PyBytes_GET_SIZE(v));
    318     return nullptr;
    319   }
    320   if (PyUnicode_Check(v)) {
    321 #if PY_MAJOR_VERSION >= 3
    322     Py_ssize_t size;
    323     const char* str = PyUnicode_AsUTF8AndSize(v, &size);
    324     if (str == nullptr) return ErrorConvertingUnicodeString;
    325     out->assign(str, size);
    326     return nullptr;
    327 #else
    328     PyObject* py_str = PyUnicode_AsUTF8String(v);
    329     if (py_str == nullptr) return ErrorConvertingUnicodeString;
    330     out->assign(PyBytes_AS_STRING(py_str), PyBytes_GET_SIZE(py_str));
    331     Py_DECREF(py_str);
    332     return nullptr;
    333 #endif
    334   }
    335   return ErrorMixedTypes;
    336 }
    337 
    338 DEFINE_HELPER(ConvertString, string, DT_STRING, ConvertOneString);
    339 
    340 // Complex support
    341 
    342 const char* ConvertOneComplex(PyObject* v, complex128* out) {
    343   if (PyComplex_Check(v)) {
    344     *out = complex128(PyComplex_RealAsDouble(v), PyComplex_ImagAsDouble(v));
    345     return nullptr;
    346   } else if (PyIsInstance(v, &PyComplexFloatingArrType_Type)) {  // NumPy
    347     auto as_complex = PyComplex_AsCComplex(v);
    348     *out = complex128(as_complex.real, as_complex.imag);
    349     return nullptr;
    350   }
    351   return ErrorMixedTypes;
    352 }
    353 
    354 DEFINE_HELPER(ConvertComplex, complex128, DT_COMPLEX128, ConvertOneComplex);
    355 
    356 // Bool support
    357 
    358 const char* ConvertOneBool(PyObject* v, bool* out) {
    359   if (v == Py_True) {
    360     *out = true;
    361   } else if (v == Py_False) {
    362     *out = false;
    363   } else if (PyIsInstance(v, &PyBoolArrType_Type)) {  // NumPy
    364     *out = PyObject_IsTrue(v);
    365   } else {
    366     return ErrorMixedTypes;
    367   }
    368   return nullptr;
    369 }
    370 
    371 DEFINE_HELPER(ConvertBool, bool, DT_BOOL, ConvertOneBool);
    372 
    373 #undef DEFINE_HELPER
    374 
    375 }  // namespace
    376 
    377 #define RETURN_STRING_AS_STATUS(...)                             \
    378   do {                                                           \
    379     const char* _error = (__VA_ARGS__);                          \
    380     if (TF_PREDICT_TRUE(_error == nullptr)) return Status::OK(); \
    381     return errors::InvalidArgument(_error);                      \
    382   } while (0)
    383 
    384 Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret) {
    385   DataType infer_dtype;
    386   TensorShape shape;
    387   TF_RETURN_IF_ERROR(InferShapeAndType(obj, &shape, &infer_dtype));
    388   DataType requested_dtype = DT_INVALID;
    389   if (dtype != Py_None) {
    390     int32 dtype_as_int = -1;
    391     if (ConvertOneInt32(dtype, &dtype_as_int) == nullptr) {
    392       requested_dtype = static_cast<DataType>(dtype_as_int);
    393     }
    394   }
    395   // NOTE(josh11b): If don't successfully convert to the requested type,
    396   // we just try instead to create a tensor of the inferred type and
    397   // let the caller convert it to the requested type using a cast
    398   // operation.
    399   switch (requested_dtype) {
    400     case DT_FLOAT:
    401       if (ConvertFloat(obj, shape, ret) == nullptr) return Status::OK();
    402       break;
    403 
    404     case DT_DOUBLE:
    405       if (ConvertDouble(obj, shape, ret) == nullptr) return Status::OK();
    406       break;
    407 
    408     case DT_INT64:
    409       if (ConvertInt64(obj, shape, ret) == nullptr) return Status::OK();
    410       break;
    411 
    412     case DT_INT32:
    413       if (ConvertInt32(obj, shape, ret) == nullptr) return Status::OK();
    414       break;
    415 
    416     case DT_COMPLEX128:
    417       if (ConvertComplex(obj, shape, ret) == nullptr) return Status::OK();
    418       break;
    419 
    420     case DT_STRING:
    421       if (ConvertString(obj, shape, ret) == nullptr) return Status::OK();
    422       break;
    423 
    424     case DT_BOOL:
    425       if (ConvertBool(obj, shape, ret) == nullptr) return Status::OK();
    426       break;
    427 
    428     default:
    429       break;
    430   }
    431   switch (infer_dtype) {
    432     case DT_DOUBLE:
    433       // TODO(josh11b): Handle mixed floats and complex numbers?
    434       if (requested_dtype == DT_INVALID) {
    435         // TensorFlow uses float32s to represent floating point numbers
    436         // by default (for space and speed over using doubles).
    437         RETURN_STRING_AS_STATUS(ConvertFloat(obj, shape, ret));
    438       } else {
    439         // We are going to do a cast to the user's requested dtype
    440         // after this.  We use doubles for this intermediate result so
    441         // we don't lose precision that might be representable in the
    442         // final type.
    443         RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret));
    444       }
    445 
    446     case DT_INT64:
    447       if (requested_dtype == DT_INVALID) {
    448         const char* error = ConvertInt32(obj, shape, ret);
    449         if (error == ErrorFoundInt64) {
    450           error = ConvertInt64(obj, shape, ret);
    451         }
    452         if (error == ErrorFoundFloat) {
    453           error = ConvertFloat(obj, shape, ret);
    454         }
    455         // TODO(josh11b): May also want to fall back to using doubles if
    456         // error == ErrorOutOfRange?
    457         RETURN_STRING_AS_STATUS(error);
    458       } else {
    459         const char* error = ConvertInt64(obj, shape, ret);
    460         if (error == ErrorFoundFloat) {
    461           error = ConvertDouble(obj, shape, ret);
    462         }
    463         RETURN_STRING_AS_STATUS(error);
    464       }
    465 
    466     case DT_STRING:
    467       RETURN_STRING_AS_STATUS(ConvertString(obj, shape, ret));
    468 
    469     case DT_COMPLEX128:
    470       RETURN_STRING_AS_STATUS(ConvertComplex(obj, shape, ret));
    471 
    472     case DT_BOOL:
    473       RETURN_STRING_AS_STATUS(ConvertBool(obj, shape, ret));
    474 
    475     case DT_INVALID:  // Only occurs for empty tensors.
    476       *ret = Tensor(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype,
    477                     shape);
    478       return Status::OK();
    479 
    480     default:
    481       return errors::Unimplemented("Missing Python -> Tensor conversion for ",
    482                                    DataTypeString(infer_dtype));
    483   }
    484 
    485   return Status::OK();
    486 }
    487 
    488 }  // namespace tensorflow
    489