Home | History | Annotate | Download | only in python
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/python/numpy_bridge.h"
     17 #include "tensorflow/compiler/xla/shape_util.h"
     18 #include "tensorflow/core/platform/logging.h"
     19 
     20 namespace xla {
     21 
     22 namespace swig {
     23 
     24 namespace numpy {
     25 
     26 int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) {
     27   switch (primitive_type) {
     28     case PRED:
     29       return NPY_BOOL;
     30     case S8:
     31       return NPY_INT8;
     32     case S16:
     33       return NPY_INT16;
     34     case S32:
     35       return NPY_INT32;
     36     case S64:
     37       return NPY_INT64;
     38     case U8:
     39       return NPY_UINT8;
     40     case U16:
     41       return NPY_UINT16;
     42     case U32:
     43       return NPY_UINT32;
     44     case U64:
     45       return NPY_UINT64;
     46     case F16:
     47       return NPY_FLOAT16;
     48     case F32:
     49       return NPY_FLOAT32;
     50     case F64:
     51       return NPY_FLOAT64;
     52     case TUPLE:
     53       return NPY_OBJECT;
     54     default:
     55       LOG(FATAL) << "No Numpy type for XLA primitive type " << primitive_type;
     56   }
     57 }
     58 
     59 PrimitiveType NumpyTypeToPrimitiveType(int np_type) {
     60   switch (np_type) {
     61     case NPY_BOOL:
     62       return PRED;
     63     case NPY_INT8:
     64       return S8;
     65     case NPY_INT16:
     66       return S16;
     67     case NPY_INT32:
     68       return S32;
     69     case NPY_INT64:
     70       return S64;
     71     case NPY_UINT8:
     72       return U8;
     73     case NPY_UINT16:
     74       return U16;
     75     case NPY_UINT32:
     76       return U32;
     77     case NPY_UINT64:
     78       return U64;
     79     case NPY_FLOAT16:
     80       return F16;
     81     case NPY_FLOAT32:
     82       return F32;
     83     case NPY_FLOAT64:
     84       return F64;
     85     case NPY_OBJECT:
     86       return TUPLE;
     87     default:
     88       LOG(FATAL) << "No XLA primitive type for Numpy type " << np_type;
     89   }
     90 }
     91 
     92 bool NumpyTypeIsValid(int np_type) {
     93   switch (np_type) {
     94     case NPY_BOOL:
     95     case NPY_INT8:
     96     case NPY_INT16:
     97     case NPY_INT32:
     98     case NPY_INT64:
     99     case NPY_UINT8:
    100     case NPY_UINT16:
    101     case NPY_UINT32:
    102     case NPY_UINT64:
    103     case NPY_FLOAT16:
    104     case NPY_FLOAT32:
    105     case NPY_FLOAT64:
    106     case NPY_OBJECT:
    107       return true;
    108     default:
    109       return false;
    110   }
    111 }
    112 
    113 PyObject* PyShapeInfoFromXlaShape(const Shape& shape) {
    114   int np_typenum = PrimitiveTypeToNumpyType(shape.element_type());
    115   PyArray_Descr* np_dtype = PyArray_DescrFromType(np_typenum);
    116 
    117   PyObject* dimensions;
    118   if (ShapeUtil::IsTuple(shape)) {
    119     int num_elements = ShapeUtil::TupleElementCount(shape);
    120     dimensions = PyTuple_New(ShapeUtil::TupleElementCount(shape));
    121     for (int i = 0; i < num_elements; ++i) {
    122       PyTuple_SET_ITEM(
    123           dimensions, i,
    124           PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i)));
    125     }
    126   } else {
    127     int rank = ShapeUtil::Rank(shape);
    128     dimensions = PyTuple_New(rank);
    129     for (int i = 0; i < rank; ++i) {
    130       PyTuple_SET_ITEM(dimensions, i,
    131                        LongToPyIntOrPyLong(ShapeUtil::GetDimension(shape, i)));
    132     }
    133   }
    134   return PyTuple_Pack(2, np_dtype, dimensions);
    135 }
    136 
    137 // Precondition: o->ob_type == &PyArrayDescr_Type
    138 static int NumpyTypenum(PyObject* o) {
    139   return reinterpret_cast<PyArray_Descr*>(o)->type_num;
    140 }
    141 
    142 // Extracts the string held inside r and returns it as a C++ string.
    143 //
    144 // NOTE: this is an internal helper for conversion to a C++, and so decrefs r.
    145 static string ExtractStringAndDecref(PyObject* r) {
    146   auto error = [r] {
    147     return tensorflow::strings::Printf("<failed conversion of %p>", r);
    148   };
    149   if (r == nullptr) {
    150     return error();
    151   }
    152 #if PY_MAJOR_VERSION < 3
    153   string result = PyString_AsString(r);
    154 #else
    155   PyObject* bytes = PyUnicode_AsEncodedString(r, 0, 0);
    156   if (bytes == nullptr) {
    157     return error();
    158   }
    159   CHECK(PyBytes_Check(bytes));
    160   string result = PyBytes_AsString(bytes);
    161   Py_DECREF(bytes);
    162 #endif
    163   Py_DECREF(r);
    164   return result;
    165 }
    166 
    167 // Safely returns a str of the given Python object o as a C++ string.
    168 static string PyObjectCppStr(PyObject* o) {
    169   PyObject* s = PyObject_Str(o);
    170   return ExtractStringAndDecref(s);
    171 }
    172 
    173 // Safely returns a repr of the given Python object o as a C++ string.
    174 static string PyObjectCppRepr(PyObject* o) {
    175   PyObject* r = PyObject_Repr(o);
    176   return ExtractStringAndDecref(r);
    177 }
    178 
    179 StatusOr<Shape> XlaShapeFromPyShape(PyObject* o) {
    180   auto error = [o](const string& prefix) {
    181     return InvalidArgument("%s; got %s", prefix.c_str(),
    182                            PyObjectCppRepr(o).c_str());
    183   };
    184 
    185   auto get_attr = [o, &error](const string& field) -> StatusOr<PyObject*> {
    186     PyObject* result =
    187         PyObject_GetAttrString(o, const_cast<char*>(field.c_str()));
    188     if (result == nullptr) {
    189       return error(tensorflow::strings::StrCat(
    190           "Failed to get attribute of Shape object:", field));
    191     }
    192     return result;
    193   };
    194 
    195   auto call_method = [o, &error](const string& method) -> StatusOr<PyObject*> {
    196     PyObject* result =
    197         PyObject_CallMethod(o, const_cast<char*>(method.c_str()), nullptr);
    198     if (result == nullptr) {
    199       return error(tensorflow::strings::StrCat(
    200           "Failed to call method of shape object:", method));
    201     }
    202     return result;
    203   };
    204 
    205   PyObject* np_type;
    206   TF_ASSIGN_OR_RETURN(np_type, get_attr("np_dtype"));
    207   if (np_type->ob_type != &PyArrayDescr_Type) {
    208     return error("Shape attribute np_dtype is not an integer numpy dtype");
    209   }
    210   if (!NumpyTypeIsValid(NumpyTypenum(np_type))) {
    211     return error("Shape attribute np_dtype is not a valid integer numpy dtype");
    212   }
    213   const PrimitiveType element_type =
    214       NumpyTypeToPrimitiveType(NumpyTypenum(np_type));
    215   Py_DECREF(np_type);
    216 
    217   if (element_type == TUPLE) {
    218     PyObject* py_subshapes;
    219     TF_ASSIGN_OR_RETURN(py_subshapes, call_method("tuple_shapes"));
    220     if (!PyTuple_Check(py_subshapes)) {
    221       return error(
    222           "Return value of Shape method tuple_shapes() is not a tuple");
    223     }
    224     const int length = PyTuple_Size(py_subshapes);
    225     std::vector<Shape> subshapes;
    226     subshapes.reserve(length);
    227     for (int i = 0; i < length; i++) {
    228       TF_ASSIGN_OR_RETURN(
    229           const Shape& subshape,
    230           XlaShapeFromPyShape(PyTuple_GetItem(py_subshapes, i)));
    231       subshapes.push_back(subshape);
    232     }
    233     Py_DECREF(py_subshapes);
    234     return ShapeUtil::MakeTupleShape(subshapes);
    235   } else {
    236     PyObject* py_dimensions;
    237     PyObject* py_minor_to_major;
    238     TF_ASSIGN_OR_RETURN(py_dimensions, call_method("dimensions"));
    239     TF_ASSIGN_OR_RETURN(py_minor_to_major, call_method("minor_to_major"));
    240     if (!PyTuple_Check(py_dimensions)) {
    241       return error("Return value of Shape method dimensions() is not a tuple");
    242     }
    243     if (py_minor_to_major != Py_None && !PyTuple_Check(py_minor_to_major)) {
    244       return error(
    245           "Return value of Shape method minor_to_major() is neither a tuple "
    246           "nor None");
    247     }
    248     const int length = PyTuple_Size(py_dimensions);
    249     if (py_minor_to_major != Py_None &&
    250         length != PyTuple_Size(py_minor_to_major)) {
    251       return error(
    252           "Shape methods dimensions() and minor_to_major() return "
    253           "different-length tuples");
    254     }
    255     std::vector<int64> dimensions(length);
    256     std::vector<int64> minor_to_major(length);
    257     for (int i = 0; i < length; i++) {
    258       dimensions[i] = PyIntOrPyLongToLong(PyTuple_GetItem(py_dimensions, i));
    259       if (dimensions[i] == -1 && PyErr_Occurred()) {
    260         return error("Dimension is not an int");
    261       }
    262 
    263       if (py_minor_to_major != Py_None) {
    264         minor_to_major[i] =
    265             PyIntOrPyLongToLong(PyTuple_GetItem(py_minor_to_major, i));
    266         if (minor_to_major[i] == -1 && PyErr_Occurred()) {
    267           return error("Minor-to-major value is not an int");
    268         }
    269       }
    270     }
    271     bool with_layout = py_minor_to_major != Py_None;
    272     Py_DECREF(py_dimensions);
    273     Py_DECREF(py_minor_to_major);
    274     if (with_layout) {
    275       return ShapeUtil::MakeShapeWithLayout(element_type, dimensions,
    276                                             minor_to_major);
    277     } else {
    278       return ShapeUtil::MakeShape(element_type, dimensions);
    279     }
    280   }
    281 }
    282 
    283 // Helper that retrieves the member with attr_name, stringifies it if is not
    284 // None, and returns it as a C++ string.
    285 static tensorflow::gtl::optional<string> GetAttrAsString(
    286     PyObject* o, const string& attr_name) {
    287   if (!PyObject_HasAttrString(o, attr_name.c_str())) {
    288     return tensorflow::gtl::nullopt;
    289   }
    290   PyObject* attr = PyObject_GetAttrString(o, attr_name.c_str());
    291   if (attr == Py_None) {
    292     Py_DECREF(attr);
    293     return tensorflow::gtl::nullopt;
    294   }
    295   string result = PyObjectCppStr(attr);
    296   Py_DECREF(attr);
    297   return result;
    298 }
    299 
    300 // Helper that retrieves the member with attr_name, checks that it is an integer
    301 // if it is not None, and returns it as an int32 value.
    302 static tensorflow::gtl::optional<int32> GetAttrAsInt32(
    303     PyObject* o, const string& attr_name) {
    304   if (!PyObject_HasAttrString(o, attr_name.c_str())) {
    305     return tensorflow::gtl::nullopt;
    306   }
    307   PyObject* attr = PyObject_GetAttrString(o, attr_name.c_str());
    308   if (attr == Py_None) {
    309     Py_DECREF(attr);
    310     return tensorflow::gtl::nullopt;
    311   }
    312   if (!CheckPyIntOrLong(attr)) {
    313     Py_DECREF(attr);
    314     return tensorflow::gtl::nullopt;
    315   }
    316   long value = PyIntOrPyLongToLong(attr);  // NOLINT
    317   Py_DECREF(attr);
    318   if (value == -1 && PyErr_Occurred() != nullptr) {
    319     return tensorflow::gtl::nullopt;
    320   }
    321   if (static_cast<int32>(value) != value) {
    322     return tensorflow::gtl::nullopt;
    323   }
    324   return value;
    325 }
    326 
    327 StatusOr<OpMetadata> OpMetadataFromPyObject(PyObject* o) {
    328   OpMetadata result;
    329   tensorflow::gtl::optional<string> op_type = GetAttrAsString(o, "op_type");
    330   if (op_type.has_value()) {
    331     result.set_op_type(op_type.value());
    332   }
    333   tensorflow::gtl::optional<string> op_name = GetAttrAsString(o, "op_name");
    334   if (op_name.has_value()) {
    335     result.set_op_name(op_name.value());
    336   }
    337   tensorflow::gtl::optional<string> source_file =
    338       GetAttrAsString(o, "source_file");
    339   if (source_file.has_value()) {
    340     result.set_source_file(source_file.value());
    341   }
    342   tensorflow::gtl::optional<int32> source_line =
    343       GetAttrAsInt32(o, "source_line");
    344   if (source_line.has_value()) {
    345     result.set_source_line(source_line.value());
    346   }
    347   return result;
    348 }
    349 
    350 PyObject* PyObjectFromXlaLiteral(const Literal& literal) {
    351   if (ShapeUtil::IsTuple(literal.shape())) {
    352     int num_elements = ShapeUtil::TupleElementCount(literal.shape());
    353     PyObject* tuple = PyTuple_New(num_elements);
    354     for (int i = 0; i < num_elements; i++) {
    355       PyTuple_SET_ITEM(
    356           tuple, i, PyObjectFromXlaLiteral(LiteralView::Create(literal, {i})));
    357     }
    358     return tuple;
    359   } else {
    360     int rank = ShapeUtil::Rank(literal.shape());
    361     std::vector<long> dimensions(rank);  // NOLINT - PyArray requires a long*
    362     for (int i = 0; i < rank; i++) {
    363       dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i);
    364     }
    365     int np_type = PrimitiveTypeToNumpyType(literal.shape().element_type());
    366     PyObject* array =
    367         PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0);
    368     CopyLiteralToNumpyArray(np_type, literal,
    369                             reinterpret_cast<PyArrayObject*>(array));
    370     return array;
    371   }
    372 }
    373 
    374 StatusOr<std::unique_ptr<Literal>> XlaLiteralFromPyObject(PyObject* o) {
    375   if (PyTuple_Check(o)) {
    376     int num_elements = PyTuple_Size(o);
    377     std::vector<std::unique_ptr<Literal>> elements;
    378     elements.reserve(num_elements);
    379     for (int i = 0; i < num_elements; i++) {
    380       PyObject* element = PyTuple_GetItem(o, i);
    381       TF_ASSIGN_OR_RETURN(auto literal, XlaLiteralFromPyObject(element));
    382       elements.push_back(std::move(literal));
    383     }
    384     return Literal::MakeTupleOwned(std::move(elements));
    385   } else if (PyArray_Check(o)) {
    386     PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(o);
    387     int rank = PyArray_NDIM(py_array);
    388     std::vector<int64> dimensions(rank);
    389     for (int i = 0; i < rank; i++) {
    390       dimensions[i] = PyArray_DIM(py_array, i);
    391     }
    392     int np_type = PyArray_TYPE(py_array);
    393     auto literal = Literal::CreateFromDimensions(
    394         NumpyTypeToPrimitiveType(np_type), dimensions);
    395     TF_RETURN_IF_ERROR(
    396         CopyNumpyArrayToLiteral(np_type, py_array, literal.get()));
    397     return std::move(literal);
    398   } else {
    399     return InvalidArgument(
    400         "Non-tuple or Numpy array encountered in conversion to XLA literal.");
    401   }
    402 }
    403 
    404 Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
    405                                Literal* literal) {
    406   switch (np_type) {
    407     case NPY_BOOL:
    408       CopyNumpyArrayToLiteral<bool>(py_array, literal);
    409       break;
    410     case NPY_INT32:
    411       CopyNumpyArrayToLiteral<int32>(py_array, literal);
    412       break;
    413     case NPY_INT64:
    414       CopyNumpyArrayToLiteral<int64>(py_array, literal);
    415       break;
    416     case NPY_UINT8:
    417       CopyNumpyArrayToLiteral<uint8>(py_array, literal);
    418       break;
    419     case NPY_UINT32:
    420       CopyNumpyArrayToLiteral<uint32>(py_array, literal);
    421       break;
    422     case NPY_UINT64:
    423       CopyNumpyArrayToLiteral<uint64>(py_array, literal);
    424       break;
    425     case NPY_FLOAT16:
    426       CopyNumpyArrayToLiteral<half>(py_array, literal);
    427       break;
    428     case NPY_FLOAT32:
    429       CopyNumpyArrayToLiteral<float>(py_array, literal);
    430       break;
    431     case NPY_FLOAT64:
    432       CopyNumpyArrayToLiteral<double>(py_array, literal);
    433       break;
    434     default:
    435       return InvalidArgument(
    436           "No XLA literal container for Numpy type number: %d", np_type);
    437   }
    438   return Status::OK();
    439 }
    440 
    441 void CopyLiteralToNumpyArray(int np_type, const Literal& literal,
    442                              PyArrayObject* py_array) {
    443   switch (np_type) {
    444     case NPY_BOOL:
    445       CopyLiteralToNumpyArray<bool>(literal, py_array);
    446       break;
    447     case NPY_INT32:
    448       CopyLiteralToNumpyArray<int32>(literal, py_array);
    449       break;
    450     case NPY_INT64:
    451       CopyLiteralToNumpyArray<int64>(literal, py_array);
    452       break;
    453     case NPY_UINT8:
    454       CopyLiteralToNumpyArray<uint8>(literal, py_array);
    455       break;
    456     case NPY_UINT32:
    457       CopyLiteralToNumpyArray<uint32>(literal, py_array);
    458       break;
    459     case NPY_UINT64:
    460       CopyLiteralToNumpyArray<uint64>(literal, py_array);
    461       break;
    462     case NPY_FLOAT16:
    463       CopyLiteralToNumpyArray<half>(literal, py_array);
    464       break;
    465     case NPY_FLOAT32:
    466       CopyLiteralToNumpyArray<float>(literal, py_array);
    467       break;
    468     case NPY_FLOAT64:
    469       CopyLiteralToNumpyArray<double>(literal, py_array);
    470       break;
    471     default:
    472       LOG(FATAL) << "No XLA literal container for Numpy type" << np_type;
    473   }
    474 }
    475 
    476 PyObject* LongToPyIntOrPyLong(long x) {  // NOLINT
    477 #if PY_MAJOR_VERSION < 3
    478   return PyInt_FromLong(x);
    479 #else
    480   return PyLong_FromLong(x);
    481 #endif
    482 }
    483 
    484 long PyIntOrPyLongToLong(PyObject* o) {  // NOLINT
    485 #if PY_MAJOR_VERSION < 3
    486   return PyInt_AsLong(o);
    487 #else
    488   return PyLong_AsLong(o);
    489 #endif
    490 }
    491 
    492 bool CheckPyIntOrLong(PyObject* o) {
    493 #if PY_MAJOR_VERSION < 3
    494   return PyInt_Check(o);
    495 #else
    496   if (!PyLong_Check(o)) {
    497     return false;
    498   }
    499   int overflow = 0;
    500   PyLong_AsLongAndOverflow(o, &overflow);
    501   return (overflow == 0);
    502 #endif
    503 }
    504 
    505 PyObject* PyNumberToPyInt(PyObject* o) {
    506 #if PY_MAJOR_VERSION < 3
    507   return PyNumber_Int(o);
    508 #else
    509   return PyNumber_Long(o);
    510 #endif
    511 }
    512 
    513 }  // namespace numpy
    514 
    515 }  // namespace swig
    516 
    517 }  // namespace xla
    518