Home | History | Annotate | Download | only in python
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      7     http://www.apache.org/licenses/LICENSE-2.0
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     16 // SWIG typemaps and declarations for building, compiling, and
     17 // executing XLA computations, wrapping most of what is declared in
     18 // local_computation_builder.h.
     19 //
     20 // The typemaps below implement/assert the following correspondences
     21 // (with elaborations below):
     22 //
     23 //    C++                                  Python
     24 // -------------------------------------+---------------------------------------
     25 //  ComputationDataHandle              <-> int
     26 //  ArraySlice<int64>                  <-  sequence of int
     27 //  ArraySlice<ComputationDataHandle>  <-  sequence of int
     28 //  Literal                            <-> (nested tuple of) numpy ndarray
     29 //  std::vector<Literal>               <-  sequence of (nested tuple of) ndarray
     30 //  Shape                               -> pair holding (dtype, dimensions)
     31 //                                     <-  object duck-typed as xla_client.Shape
     32 //  std::vector<Shape>                 <-  sequence of xla_client.Shape objects
     33 //  PrimitiveType                      <-  int
     34 //  ArraySlice<pair<int64, int64>>     <-  sequence of int pairs
     35 //  PaddingConfig proto                <-  corresponding Python proto
     36 //  ConvolutionDimensionNumbers proto  <-  corresponding Python proto
     37 //  DotDimensionNumbers proto          <-  corresponding Python proto
     38 //
     39 // Arrows indicate whether a conversion only ever occurs in one
     40 // direction, or whether it is maintained bidirectionally.
     41 //
     42 // The Python objects corresponding to C++ Literals have the type:
     43 //
     44 //   T = ndarray | (T, ...)
     45 //
     46 // where a terminal numpy ndarray translates to a Literal with a
     47 // non-tuple Shape, an XLA primitive element type corresponding to the
     48 // ndarray's dtype. Meanwhile, a non-terminal "tuple of T" translates
     49 // to a tuple-shaped Literal whose tuple components are translated
     50 // recursively. For example, if x is a numpy ndarray in Python, with
     51 // shape (2, 3) and dtype of dtype('float32'), then x translates to a
     52 // Literal with rank 2, dimension 2 and 3, and XLA primitive type
     53 // F32. Meanwhile,
     54 //
     55 //   (x, (x, x), (x,)),
     56 //
     57 // translates to a tuple-shaped XLA Literal, whose component subshapes
     58 // are a 2x3 F32-shaped literal followed by two tuple-shaped literals.
     59 //
     60 // Shapes output by C++ become Python objects with the type:
     61 //
     62 //   T            = (dtype, S)
     63 //   S            = DIMENSIONS | TUPLE_SHAPES
     64 //   DIMENSIONS   = (int, ...)
     65 //   TUPLE_SHAPES = (T, ...)
     66 //
     67 // In the pair described by the T rule, the terminal dtype determines
     68 // whether S expands as DIMENSIONS or TUPLE_SHAPES. Namely if it is
     69 // dtype('O'), numpy's object dtype, the structure represents a tuple
     70 // shape and the expansion of the non-terminal S is
     71 // TUPLE_SHAPES. Otherwise, dtype describes a primitive element type
     72 // and S expands into DIMENSIONS giving dimension sizes. For example:
     73 //
     74 //   (dtype('float32'), (3, 5, 7))
     75 //
     76 // describes a 3x5x7 array of F32s, and
     77 //
     78 //   (dtype('O'), ((dtype('float32'), (2, 3)),
     79 //                 (dtype('float64'), (4, 5))))
     80 //
     81 // describes a tuple shape with two subshapes: the first a 2x3 F32,
     82 // and the other a 4x5 F64.
     83 //
     84 // The Python int corresponding to a PrimitiveType enum must be valid
     85 // per xla_data.proto (e.g. xla_data.PRED, xla_data.F32).
     86 //
     87 // The SWIG object wrappers generated by this file are not intended
     88 // for end use, but rather for internal use in the Python XLA client,
     89 // xla_client.py.
     90 //
     91 // One central reason for the Python-side indirection is that the
     92 // Python-side objects produced by the typemaps in this file are
     93 // further packaged up by xla_client before being passed on. For
     94 // instance, xla_client wraps the long produced for a C++
     95 // ComputationDataHandle in a Python ComputationDataHandle proto,
     96 // rather than exposing a raw long outside of the client. Similarly,
     97 // the Python pair produced for a C++ Shape is further wrapped in a
     98 // Python class (xla_client.Shape) so as not to expose the raw pair
     99 // externally.
    100 //
    101 // Other SWIG object wrappers (e.g. of LocalComputation) are further
    102 // wrapped by xla_client in order to set up a custom destructor that
    103 // triggers memory deallocation on the C++ side.
    105 %module(threads="1") local_computation_builder
    107 // Keep the GIL except where explicitly specified.
    108 %nothread;
    110 %include "tensorflow/python/platform/base.i"
    112 %{
    113 // Must be included first
    114 #include "tensorflow/python/lib/core/numpy.h"
    116 #include "tensorflow/compiler/xla/literal_util.h"
    117 #include "tensorflow/compiler/xla/shape_util.h"
    118 #include "tensorflow/compiler/xla/xla_data.pb.h"
    119 #include "tensorflow/core/lib/gtl/array_slice.h"
    120 #include "tensorflow/compiler/xla/python/numpy_bridge.h"
    121 #include "tensorflow/compiler/xla/python/local_computation_builder.h"
    123 using namespace xla;
    124 using namespace xla::swig;
    126 namespace xla {
    127 namespace swig {
    129 bool GetIntAttr(PyObject* o, const char* field, int64* result) {
    130   PyObject* fo = PyObject_GetAttrString(o, field);
    131   if (!fo) {
    132     return false;
    133   }
    134   const int64 value = numpy::PyIntOrPyLongToLong(fo);
    135   if (value == -1 && PyErr_Occurred()) {
    136     Py_DECREF(fo);
    137     return false;
    138   }
    139   Py_DECREF(fo);
    140   *result = value;
    141   return true;
    142 }
    144 }
    145 }
    146 %}
    148 // Required to use PyArray_* functions.
    149 %init %{
    150 tensorflow::ImportNumpy();
    151 %}
    153 // ComputationDataHandle
    155 %typemap(in) const ComputationDataHandle& (ComputationDataHandle temp) {
    156   const int64 handle = numpy::PyIntOrPyLongToLong($input);
    157   if (handle == -1 && PyErr_Occurred()) {
    158     return NULL;
    159   }
    160   temp.set_handle(handle);
    161   $1 = &temp;
    162 }
    164 %typemap(out) ComputationDataHandle {
    165   $result = numpy::LongToPyIntOrPyLong($1.handle());
    166 }
    168 %typemap(out) StatusOr<xla::swig::CompiledLocalComputation*> {
    169   if ($1.ok()) {
    170     auto* value = $1.ValueOrDie();
    171     {
    172       auto* $1 = value;
    173       $typemap(out, xla::swig::CompiledLocalComputation*)
    174     }
    175   } else {
    176     PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    177     return NULL;
    178   }
    179 }
    181 %typemap(out) StatusOr< std::unique_ptr<Literal> > {
    182   if ($1.ok()) {
    183     std::unique_ptr<Literal> value = $1.ConsumeValueOrDie();
    184     $result = numpy::PyObjectFromXlaLiteral(*value);
    185   } else {
    186     PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    187     return NULL;
    188   }
    189 }
    191 %typemap(out) StatusOr<xla::swig::LocalComputation*> {
    192   if ($1.ok()) {
    193     auto* value = $1.ValueOrDie();
    194     {
    195       auto* $1 = value;
    196       $typemap(out, xla::swig::LocalComputation*)
    197     }
    198   } else {
    199     PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    200     return NULL;
    201   }
    202 }
    204 %typemap(out) StatusOr<Shape> {
    205   if ($1.ok()) {
    206     $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie());
    207   } else {
    208     PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    209     return NULL;
    210   }
    211 }
    213 %typemap(out) Status {
    214   if (!$1.ok()) {
    215     PyErr_SetString(
    216         PyExc_RuntimeError, $1.ToString().c_str());
    217     return NULL;
    218   }
    219   $result = Py_None;
    220 }
    222 // ArraySlice<int64>
    224 %typemap(in) tensorflow::gtl::ArraySlice<int64>
    225     (std::vector<int64> temps) {
    226   if (!PySequence_Check($input)) {
    227     PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    228     return NULL;
    229   }
    230   const int size = PySequence_Size($input);
    231   temps.resize(size);
    232   for (int i = 0; i < size; ++i) {
    233     PyObject* o = PySequence_GetItem($input, i);
    234     PyObject* py_int = numpy::PyNumberToPyInt(o);
    235     if (!py_int) {
    236       PyErr_SetString(
    237           PyExc_TypeError,
    238           "Argument sequence element cannot be converted to int");
    239       Py_DECREF(o);
    240       return NULL;
    241     }
    242     temps[i] = numpy::PyIntOrPyLongToLong(py_int);
    243     if (temps[i] == -1 && PyErr_Occurred()) {
    244       Py_DECREF(py_int);
    245       Py_DECREF(o);
    246       return NULL;
    247     }
    248     Py_DECREF(py_int);
    249     Py_DECREF(o);
    250   }
    251   $1 = temps;
    252 }
    254 // ArraySlice<ComputationDataHandle>
    256 %typemap(in) tensorflow::gtl::ArraySlice<ComputationDataHandle>
    257     (std::vector<ComputationDataHandle> temps) {
    258   if (!PySequence_Check($input)) {
    259     PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    260     return NULL;
    261   }
    262   const int size = PySequence_Size($input);
    263   temps.resize(size);
    264   for (int i = 0; i < size; ++i) {
    265     PyObject* o = PySequence_GetItem($input, i);
    266     PyObject* py_int = numpy::PyNumberToPyInt(o);
    267     if (!py_int) {
    268       PyErr_SetString(
    269           PyExc_TypeError,
    270           "Argument sequence element cannot be converted to int");
    271       return NULL;
    272     }
    273     const int64 handle = numpy::PyIntOrPyLongToLong(py_int);
    274     if (handle == -1 && PyErr_Occurred()) {
    275       Py_DECREF(py_int);
    276       Py_DECREF(o);
    277       return NULL;
    278     }
    279     temps[i].set_handle(handle);
    280     Py_DECREF(py_int);
    281     Py_DECREF(o);
    282   }
    283   $1 = temps;
    284 }
    286 // ArraySlice<LocalShapedBuffer*>
    288 %typemap(in) tensorflow::gtl::ArraySlice<xla::swig::LocalShapedBuffer*>
    289     (std::vector<LocalShapedBuffer*> temps) {
    290   if (!PySequence_Check($input)) {
    291     PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    292     return NULL;
    293   }
    294   const int size = PySequence_Size($input);
    295   temps.reserve(size);
    296   for (int i = 0; i < size; ++i) {
    297     PyObject* o = PySequence_GetItem($input, i);
    298     LocalShapedBuffer* lsbp;
    299     if ((SWIG_ConvertPtr(o, (void**) &lsbp, $descriptor(xla::swig::LocalShapedBuffer*),
    300                          SWIG_POINTER_EXCEPTION)) == -1) {
    301       return NULL;
    302     }
    303     temps.push_back(lsbp);
    304     Py_DECREF(o);
    305   }
    306   $1 = temps;
    307 }
    309 // Literal
    311 %typemap(in) const Literal& (StatusOr< std::unique_ptr<Literal> > literal_status) {
    312   literal_status = numpy::XlaLiteralFromPyObject($input);
    313   if (!literal_status.ok()) {
    314     PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
    315     return NULL;
    316   }
    317   $1 = literal_status.ValueOrDie().get();
    318 }
    320 %typemap(out) std::unique_ptr<Literal> {
    321   $result = numpy::PyObjectFromXlaLiteral(*$1);
    322 }
    324 %typemap(out) StatusOr< std::unique_ptr<Literal> > {
    325   if (!$1.ok()) {
    326     PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    327     return NULL;
    328   }
    329   $result = numpy::PyObjectFromXlaLiteral(*$1.ValueOrDie());
    330 }
    332 %typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) {
    333   if (!PySequence_Check($input)) {
    334     PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    335     return NULL;
    336   }
    337   const int size = PySequence_Size($input);
    338   for (int i = 0; i < size; ++i) {
    339     PyObject* o = PySequence_GetItem($input, i);
    340     StatusOr< std::unique_ptr<Literal> > literal_status = numpy::XlaLiteralFromPyObject(o);
    341     if (!literal_status.ok()) {
    342       PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
    343       Py_DECREF(o);
    344       return NULL;
    345     }
    346     temps.push_back(std::move(*literal_status.ConsumeValueOrDie()));
    347     Py_DECREF(o);
    348   }
    349   $1 = &temps;
    350 }
    352 // OpMetadata
    354 %typemap(in) const OpMetadata& (OpMetadata temp) {
    355   StatusOr<OpMetadata> statusor = numpy::OpMetadataFromPyObject($input);
    356   if (!statusor.ok()) {
    357     PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
    358     return NULL;
    359   }
    360   temp = std::move(statusor).ValueOrDie();
    361   $1 = &temp;
    362 }
    364 // Shape
    366 %typemap(in) const Shape& (Shape temp) {
    367   StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape($input);
    368   if (!statusor.ok()) {
    369     PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
    370     return NULL;
    371   }
    372   temp = std::move(statusor).ValueOrDie();
    373   $1 = &temp;
    374 }
    376 %typemap(in) const tensorflow::gtl::optional<Shape>& (
    377     tensorflow::gtl::optional<Shape> temp) {
    378   if ($input == Py_None) {
    379     temp = tensorflow::gtl::nullopt;
    380     $1 = &temp;
    381   } else {
    382     StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape($input);
    383     if (!statusor.ok()) {
    384       PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
    385       return NULL;
    386     }
    387     temp = std::move(statusor).ValueOrDie();
    388     $1 = &temp;
    389   }
    390 }
    392 %typemap(out) std::unique_ptr<Shape> {
    393   $result = numpy::PyShapeInfoFromXlaShape(*$1);
    394 }
    396 %typemap(in) const std::vector<Shape>& (std::vector<Shape> temps) {
    397   if (!PySequence_Check($input)) {
    398     PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    399     return NULL;
    400   }
    401   const int size = PySequence_Size($input);
    402   for (int i = 0; i < size; ++i) {
    403     PyObject* o = PySequence_GetItem($input, i);
    404     StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
    405     Py_DECREF(o);
    406     if (!statusor.ok()) {
    407       PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
    408       return NULL;
    409     }
    410     temps.push_back(statusor.ConsumeValueOrDie());
    411   }
    412   $1 = &temps;
    413 }
    415 %typemap(in) const std::vector<tensorflow::gtl::optional<Shape> >& (
    416     std::vector<tensorflow::gtl::optional<Shape> > temps) {
    417   if (!PySequence_Check($input)) {
    418     PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    419     return NULL;
    420   }
    421   const int size = PySequence_Size($input);
    422   for (int i = 0; i < size; ++i) {
    423     PyObject* o = PySequence_GetItem($input, i);
    424     if (o == Py_None) {
    425       temps.push_back(tensorflow::gtl::nullopt);
    426     } else {
    427       StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
    428       Py_DECREF(o);
    429       if (!statusor.ok()) {
    430         PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
    431         return NULL;
    432       }
    433       temps.push_back(statusor.ConsumeValueOrDie());
    434     }
    435   }
    436   $1 = &temps;
    437 }
    439 // PrimitiveType
    441 %typemap(in) PrimitiveType {
    442   PyObject* py_int = numpy::PyNumberToPyInt($input);
    443   if (!py_int) {
    444     PyErr_SetString(PyExc_TypeError, "Argument cannot be converted to int");
    445     return NULL;
    446   }
    447   const long value = numpy::PyIntOrPyLongToLong(py_int);
    448   if (value == -1 && PyErr_Occurred()) {
    449     Py_DECREF(py_int);
    450     return NULL;
    451   }
    452   if (!PrimitiveType_IsValid(value)) {
    453     PyErr_SetString(
    454         PyExc_TypeError, "Argument not valid for PrimitiveType enum");
    455     Py_DECREF(py_int);
    456     return NULL;
    457   }
    458   $1 = static_cast<PrimitiveType>(value);
    459 }
    461 // ArraySlice<pair<int64, int64>>
    463 %typemap(in) tensorflow::gtl::ArraySlice<std::pair<int64, int64> >
    464     (std::vector<std::pair<int64, int64> > temps) {
    465   if (!PySequence_Check($input)) {
    466     PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    467     return NULL;
    468   }
    469   const int size = PySequence_Size($input);
    470   temps.reserve(size);
    471   for (int i = 0; i < size; ++i) {
    472     PyObject* o = PySequence_GetItem($input, i);
    473     if (!o) {
    474       return NULL;
    475     }
    476     PyObject* first = PyTuple_GetItem(o, 0);
    477     if (!first) {
    478       Py_DECREF(o);
    479       return NULL;
    480     }
    481     PyObject* first_pyint = numpy::PyNumberToPyInt(first);
    482     if (!first_pyint) {
    483       PyErr_SetString(
    484           PyExc_TypeError,
    485           "First pair item cannot be converted to int");
    486       Py_DECREF(o);
    487       return NULL;
    488     }
    489     PyObject* second = PyTuple_GetItem(o, 1);
    490     if (!second) {
    491       Py_DECREF(o);
    492       Py_DECREF(first_pyint);
    493       return NULL;
    494     }
    495     PyObject* second_pyint = numpy::PyNumberToPyInt(second);
    496     if (!second_pyint) {
    497       PyErr_SetString(
    498           PyExc_TypeError,
    499           "Second pair item cannot be converted to int");
    500       Py_DECREF(o);
    501       Py_DECREF(first_pyint);
    502       return NULL;
    503     }
    504     const int64 first_value = numpy::PyIntOrPyLongToLong(first_pyint);
    505     if (first_value == -1 && PyErr_Occurred()) {
    506       Py_DECREF(o);
    507       Py_DECREF(first_pyint);
    508       Py_DECREF(second_pyint);
    509       return NULL;
    510     }
    511     const int64 second_value = numpy::PyIntOrPyLongToLong(second_pyint);
    512     if (second_value == -1 && PyErr_Occurred()) {
    513       Py_DECREF(o);
    514       Py_DECREF(first_pyint);
    515       Py_DECREF(second_pyint);
    516       return NULL;
    517     }
    518     temps.push_back(std::make_pair(first_value, second_value));
    519     Py_DECREF(o);
    520   }
    521   $1 = temps;
    522 }
    524 // DotDimensionNumbers
    526 %typemap(in) const DotDimensionNumbers&
    527     (DotDimensionNumbers dimension_numbers) {
    528   int length;
    530   /* lhs_contracting_dimensions */
    531   PyObject* lhs_contracting_dimensions = PyObject_GetAttrString(
    532       $input, "lhs_contracting_dimensions");
    533   if (!lhs_contracting_dimensions) {
    534     return NULL;
    535   }
    537   length = PySequence_Size(lhs_contracting_dimensions);
    538   if (length == -1) {
    539     Py_DECREF(lhs_contracting_dimensions);
    540     return NULL;
    541   }
    543   for (int i = 0; i < length; ++i) {
    544     PyObject* item = PySequence_GetItem(lhs_contracting_dimensions, i);
    545     if (!item) {
    546       Py_DECREF(lhs_contracting_dimensions);
    547       return NULL;
    548     }
    549     const int64 dimension = numpy::PyIntOrPyLongToLong(item);
    550     if (dimension == -1 && PyErr_Occurred()) {
    551       Py_DECREF(item);
    552       Py_DECREF(lhs_contracting_dimensions);
    553       return NULL;
    554     }
    555     dimension_numbers.add_lhs_contracting_dimensions(dimension);
    556     Py_DECREF(item);
    557   }
    558   Py_DECREF(lhs_contracting_dimensions);
    560   /* rhs_contracting_dimensions */
    561   PyObject* rhs_contracting_dimensions = PyObject_GetAttrString(
    562       $input, "rhs_contracting_dimensions");
    563   if (!lhs_contracting_dimensions) {
    564     return NULL;
    565   }
    567   length = PySequence_Size(rhs_contracting_dimensions);
    568   if (length == -1) {
    569     Py_DECREF(rhs_contracting_dimensions);
    570     return NULL;
    571   }
    573   for (int i = 0; i < length; ++i) {
    574     PyObject* item = PySequence_GetItem(rhs_contracting_dimensions, i);
    575     if (!item) {
    576       Py_DECREF(rhs_contracting_dimensions);
    577       return NULL;
    578     }
    579     const int64 dimension = numpy::PyIntOrPyLongToLong(item);
    580     if (dimension == -1 && PyErr_Occurred()) {
    581       Py_DECREF(item);
    582       Py_DECREF(rhs_contracting_dimensions);
    583       return NULL;
    584     }
    585     dimension_numbers.add_rhs_contracting_dimensions(dimension);
    586     Py_DECREF(item);
    587   }
    588   Py_DECREF(rhs_contracting_dimensions);
    590   /* lhs_batch_dimensions */
    591   PyObject* lhs_batch_dimensions = PyObject_GetAttrString(
    592       $input, "lhs_batch_dimensions");
    593   if (!lhs_batch_dimensions) {
    594     return NULL;
    595   }
    597   length = PySequence_Size(lhs_batch_dimensions);
    598   if (length == -1) {
    599     Py_DECREF(lhs_batch_dimensions);
    600     return NULL;
    601   }
    603   for (int i = 0; i < length; ++i) {
    604     PyObject* item = PySequence_GetItem(lhs_batch_dimensions, i);
    605     if (!item) {
    606       Py_DECREF(lhs_batch_dimensions);
    607       return NULL;
    608     }
    609     const int64 dimension = numpy::PyIntOrPyLongToLong(item);
    610     if (dimension == -1 && PyErr_Occurred()) {
    611       Py_DECREF(item);
    612       Py_DECREF(lhs_batch_dimensions);
    613       return NULL;
    614     }
    615     dimension_numbers.add_lhs_batch_dimensions(dimension);
    616     Py_DECREF(item);
    617   }
    618   Py_DECREF(lhs_batch_dimensions);
    620   /* rhs_batch_dimensions */
    621   PyObject* rhs_batch_dimensions = PyObject_GetAttrString(
    622       $input, "rhs_batch_dimensions");
    623   if (!rhs_batch_dimensions) {
    624     return NULL;
    625   }
    627   length = PySequence_Size(rhs_batch_dimensions);
    628   if (length == -1) {
    629     Py_DECREF(rhs_batch_dimensions);
    630     return NULL;
    631   }
    633   for (int i = 0; i < length; ++i) {
    634     PyObject* item = PySequence_GetItem(rhs_batch_dimensions, i);
    635     if (!item) {
    636       Py_DECREF(rhs_batch_dimensions);
    637       return NULL;
    638     }
    639     const int64 dimension = numpy::PyIntOrPyLongToLong(item);
    640     if (dimension == -1 && PyErr_Occurred()) {
    641       Py_DECREF(item);
    642       Py_DECREF(rhs_batch_dimensions);
    643       return NULL;
    644     }
    645     dimension_numbers.add_rhs_batch_dimensions(dimension);
    646     Py_DECREF(item);
    647   }
    648   Py_DECREF(rhs_batch_dimensions);
    650   $1 = &dimension_numbers;
    651 }
    653 // PaddingConfig
    655 %typemap(in) const PaddingConfig&
    656     (PaddingConfig padding_config) {
    657   PyObject* dimensions = PyObject_GetAttrString($input, "dimensions");
    658   if (!dimensions) {
    659     return NULL;
    660   }
    662   int length = PySequence_Size(dimensions);
    663   if (length == -1) {
    664     Py_DECREF(dimensions);
    665     return NULL;
    666   }
    668   for (int i = 0; i < length; ++i) {
    669     PyObject* item = PySequence_GetItem(dimensions, i);
    670     if (!item) {
    671       Py_DECREF(dimensions);
    672       return NULL;
    673     }
    674     int64 edge_padding_low, edge_padding_high, interior_padding;
    675     if (!GetIntAttr(item, "edge_padding_low", &edge_padding_low)
    676         || !GetIntAttr(item, "edge_padding_high", &edge_padding_high)
    677         || !GetIntAttr(item, "interior_padding", &interior_padding)) {
    678       Py_DECREF(item);
    679       Py_DECREF(dimensions);
    680       return NULL;
    681     }
    682     Py_DECREF(item);
    684     PaddingConfig::PaddingConfigDimension* dimension =
    685         padding_config.add_dimensions();
    686     dimension->set_edge_padding_low(edge_padding_low);
    687     dimension->set_edge_padding_high(edge_padding_high);
    688     dimension->set_interior_padding(interior_padding);
    689   }
    690   Py_DECREF(dimensions);
    692   $1 = &padding_config;
    693 }
    695 // ConvolutionDimensionNumbers
    697 %typemap(in) const ConvolutionDimensionNumbers&
    698     (ConvolutionDimensionNumbers dimension_numbers) {
    699   int64 value;
    701   if (!GetIntAttr($input, "input_batch_dimension", &value)) {
    702     return NULL;
    703   }
    704   dimension_numbers.set_input_batch_dimension(value);
    706   if (!GetIntAttr($input, "input_feature_dimension", &value)) {
    707     return NULL;
    708   }
    709   dimension_numbers.set_input_feature_dimension(value);
    711   if (!GetIntAttr($input, "output_batch_dimension", &value)) {
    712     return NULL;
    713   }
    714   dimension_numbers.set_output_batch_dimension(value);
    716   if (!GetIntAttr($input, "output_feature_dimension", &value)) {
    717     return NULL;
    718   }
    719   dimension_numbers.set_output_feature_dimension(value);
    721   if (!GetIntAttr($input, "kernel_output_feature_dimension", &value)) {
    722     return NULL;
    723   }
    724   dimension_numbers.set_kernel_output_feature_dimension(value);
    726   if (!GetIntAttr($input, "kernel_input_feature_dimension", &value)) {
    727     return NULL;
    728   }
    729   dimension_numbers.set_kernel_input_feature_dimension(value);
    731   PyObject* o;
    732   int length;
    734   o = PyObject_GetAttrString($input, "input_spatial_dimensions");
    735   if (!o) {
    736     return NULL;
    737   }
    738   length = PySequence_Size(o);
    739   if (length == -1) {
    740     Py_DECREF(o);
    741     return NULL;
    742   }
    743   for (int i = 0; i < length; ++i) {
    744     PyObject* item = PySequence_GetItem(o, i);
    745     if (!item) {
    746       Py_DECREF(o);
    747       return NULL;
    748     }
    749     const int64 dimension = numpy::PyIntOrPyLongToLong(item);
    750     if (dimension == -1 && PyErr_Occurred()) {
    751       Py_DECREF(item);
    752       Py_DECREF(o);
    753       return NULL;
    754     }
    755     dimension_numbers.add_input_spatial_dimensions(dimension);
    756     Py_DECREF(item);
    757   }
    758   Py_DECREF(o);
    760   o = PyObject_GetAttrString($input, "kernel_spatial_dimensions");
    761   if (!o) {
    762     return NULL;
    763   }
    764   length = PySequence_Size(o);
    765   if (length == -1) {
    766     Py_DECREF(o);
    767     return NULL;
    768   }
    769   for (int i = 0; i < length; ++i) {
    770     PyObject* item = PySequence_GetItem(o, i);
    771     if (!item) {
    772       Py_DECREF(o);
    773       return NULL;
    774     }
    775     const int64 dimension = numpy::PyIntOrPyLongToLong(item);
    776     if (dimension == -1 && PyErr_Occurred()) {
    777       Py_DECREF(item);
    778       Py_DECREF(o);
    779       return NULL;
    780     }
    781     dimension_numbers.add_kernel_spatial_dimensions(dimension);
    782     Py_DECREF(item);
    783   }
    784   Py_DECREF(o);
    786   o = PyObject_GetAttrString($input, "output_spatial_dimensions");
    787   if (!o) {
    788     return NULL;
    789   }
    790   length = PySequence_Size(o);
    791   if (length == -1) {
    792     Py_DECREF(o);
    793     return NULL;
    794   }
    795   for (int i = 0; i < length; ++i) {
    796     PyObject* item = PySequence_GetItem(o, i);
    797     if (!item) {
    798       Py_DECREF(o);
    799       return NULL;
    800     }
    801     const int64 dimension = numpy::PyIntOrPyLongToLong(item);
    802     if (dimension == -1 && PyErr_Occurred()) {
    803       Py_DECREF(item);
    804       Py_DECREF(o);
    805       return NULL;
    806     }
    807     dimension_numbers.add_output_spatial_dimensions(dimension);
    808     Py_DECREF(item);
    809   }
    810   Py_DECREF(o);
    812   $1 = &dimension_numbers;
    813 }
    815 // ExecutableBuildOptions
    817 %typemap(in) const ExecutableBuildOptions*
    818     (ExecutableBuildOptions build_options) {
    819   if ($input == Py_None) {
    820     $1 = NULL;
    821   } else {
    822     PyObject* o = PyObject_GetAttrString($input, "generate_hlo_graph");
    823     if (!o) {
    824       return NULL;
    825     }
    826     if (o != Py_None) {
    827       if (!PyString_Check(o)) {
    828         PyErr_SetString(PyExc_TypeError, "ExecutableBuildOptions.generate_hlo_graph must be a string or None.");
    829         return NULL;
    830       }
    831       build_options.set_generate_hlo_graph(PyString_AsString(o));
    832     }
    833     Py_DECREF(o);
    835     o = PyObject_GetAttrString($input, "result_shape");
    836     if (o == nullptr) {
    837       return nullptr;
    838     }
    839     if (o != Py_None) {
    840       StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
    841       if (!statusor.ok()) {
    842         PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat("ExecutableBuildOptions.result_shape could not be created from Python shape value: ", statusor.status().ToString()).c_str());
    843         Py_DECREF(o);
    844         return NULL;
    845       }
    846       build_options.set_result_layout(statusor.ValueOrDie());
    847     }
    848     Py_DECREF(o);
    850     $1 = &build_options;
    851   }
    852 }
    // Exposure list for the Python wrapper: start by ignoring everything, then
    // re-enable only the xla::swig names listed below. The declarations
    // themselves come from the %include of local_computation_builder.h at the
    // end of this section.
    // NOTE(review): %ignoreall / %unignore / %unignoreall are not defined in
    // this part of the file; presumably project-level SWIG macros built on
    // %ignore / %rename -- confirm where they are defined.
    854 %ignoreall
    // Enclosing namespaces must be un-ignored before anything inside them.
    855 %unignore xla;
    856 %unignore xla::swig;
    // Namespace-scope entries in xla::swig.
    857 %unignore xla::swig::InitializeReplicaCount;
    858 %unignore xla::swig::GetReplicaCount;
    859 %unignore xla::swig::TransferToInfeedLocal;
    860 %unignore xla::swig::TransferToInfeedLocalReplica;
    861 %unignore xla::swig::TransferFromOutfeedLocalReplica;
    // LocalShapedBuffer and the members exposed on it.
    862 %unignore xla::swig::LocalShapedBuffer;
    863 %unignore xla::swig::LocalShapedBuffer::FromLiteral;
    864 %unignore xla::swig::LocalShapedBuffer::ToLiteral;
    // CompiledLocalComputation and the members exposed on it.
    865 %unignore xla::swig::CompiledLocalComputation;
    866 %unignore xla::swig::CompiledLocalComputation::Execute;
    867 %unignore xla::swig::CompiledLocalComputation::ExecuteWithShapedBuffers;
    // LocalComputation and the members exposed on it.
    868 %unignore xla::swig::LocalComputation;
    869 %unignore xla::swig::LocalComputation::Compile;
    870 %unignore xla::swig::LocalComputation::GetReturnValueShape;
    // LocalComputationBuilder: the class, its constructor, and each builder
    // member that should be reachable from Python. A member that is not listed
    // here stays hidden by the %ignoreall above.
    871 %unignore xla::swig::LocalComputationBuilder;
    872 %unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder;
    873 %unignore xla::swig::LocalComputationBuilder::Build;
    874 %unignore xla::swig::LocalComputationBuilder::SetOpMetadata;
    875 %unignore xla::swig::LocalComputationBuilder::ClearOpMetadata;
    876 %unignore xla::swig::LocalComputationBuilder::Parameter;
    877 %unignore xla::swig::LocalComputationBuilder::GetShape;
    878 %unignore xla::swig::LocalComputationBuilder::GetReturnValueShape;
    879 %unignore xla::swig::LocalComputationBuilder::Infeed;
    880 %unignore xla::swig::LocalComputationBuilder::Outfeed;
    881 %unignore xla::swig::LocalComputationBuilder::ConstantLiteral;
    882 %unignore xla::swig::LocalComputationBuilder::ConstantR0;
    883 %unignore xla::swig::LocalComputationBuilder::Broadcast;
    884 %unignore xla::swig::LocalComputationBuilder::Pad;
    885 %unignore xla::swig::LocalComputationBuilder::Reshape;
    886 %unignore xla::swig::LocalComputationBuilder::Collapse;
    887 %unignore xla::swig::LocalComputationBuilder::CrossReplicaSum;
    888 %unignore xla::swig::LocalComputationBuilder::Slice;
    889 %unignore xla::swig::LocalComputationBuilder::DynamicSlice;
    890 %unignore xla::swig::LocalComputationBuilder::DynamicUpdateSlice;
    891 %unignore xla::swig::LocalComputationBuilder::ConcatInDim;
    892 %unignore xla::swig::LocalComputationBuilder::SelectAndScatterWithGeneralPadding;
    893 %unignore xla::swig::LocalComputationBuilder::Select;
    894 %unignore xla::swig::LocalComputationBuilder::Tuple;
    895 %unignore xla::swig::LocalComputationBuilder::GetTupleElement;
    896 %unignore xla::swig::LocalComputationBuilder::ConvertElementType;
    897 %unignore xla::swig::LocalComputationBuilder::Call;
    898 %unignore xla::swig::LocalComputationBuilder::Transpose;
    899 %unignore xla::swig::LocalComputationBuilder::Rev;
    900 %unignore xla::swig::LocalComputationBuilder::Clamp;
    901 %unignore xla::swig::LocalComputationBuilder::Map;
    902 %unignore xla::swig::LocalComputationBuilder::Reduce;
    903 %unignore xla::swig::LocalComputationBuilder::ReduceWindowWithGeneralPadding;
    904 %unignore xla::swig::LocalComputationBuilder::RngNormal;
    905 %unignore xla::swig::LocalComputationBuilder::RngUniform;
    906 %unignore xla::swig::LocalComputationBuilder::RngBernoulli;
    907 %unignore xla::swig::LocalComputationBuilder::While;
    908 %unignore xla::swig::LocalComputationBuilder::Conditional;
    909 %unignore xla::swig::LocalComputationBuilder::Eq;
    910 %unignore xla::swig::LocalComputationBuilder::Ne;
    911 %unignore xla::swig::LocalComputationBuilder::Ge;
    912 %unignore xla::swig::LocalComputationBuilder::Gt;
    913 %unignore xla::swig::LocalComputationBuilder::Lt;
    914 %unignore xla::swig::LocalComputationBuilder::Le;
    915 %unignore xla::swig::LocalComputationBuilder::Dot;
    916 %unignore xla::swig::LocalComputationBuilder::DotGeneral;
    917 %unignore xla::swig::LocalComputationBuilder::ConvGeneralDilated;
    918 %unignore xla::swig::LocalComputationBuilder::Add;
    919 %unignore xla::swig::LocalComputationBuilder::Sub;
    920 %unignore xla::swig::LocalComputationBuilder::Mul;
    921 %unignore xla::swig::LocalComputationBuilder::Div;
    922 %unignore xla::swig::LocalComputationBuilder::Rem;
    923 %unignore xla::swig::LocalComputationBuilder::Max;
    924 %unignore xla::swig::LocalComputationBuilder::Min;
    925 %unignore xla::swig::LocalComputationBuilder::And;
    926 %unignore xla::swig::LocalComputationBuilder::Or;
    927 %unignore xla::swig::LocalComputationBuilder::Not;
    928 %unignore xla::swig::LocalComputationBuilder::Abs;
    929 %unignore xla::swig::LocalComputationBuilder::Exp;
    930 %unignore xla::swig::LocalComputationBuilder::Floor;
    931 %unignore xla::swig::LocalComputationBuilder::Ceil;
    932 %unignore xla::swig::LocalComputationBuilder::Round;
    933 %unignore xla::swig::LocalComputationBuilder::Log;
    934 %unignore xla::swig::LocalComputationBuilder::Sign;
    935 %unignore xla::swig::LocalComputationBuilder::Cos;
    936 %unignore xla::swig::LocalComputationBuilder::Sin;
    937 %unignore xla::swig::LocalComputationBuilder::Tanh;
    938 %unignore xla::swig::LocalComputationBuilder::SqrtF32;
    939 %unignore xla::swig::LocalComputationBuilder::SquareF32;
    940 %unignore xla::swig::LocalComputationBuilder::Pow;
    941 %unignore xla::swig::LocalComputationBuilder::IsFinite;
    942 %unignore xla::swig::LocalComputationBuilder::ReciprocalF32;
    943 %unignore xla::swig::LocalComputationBuilder::Neg;
    944 %unignore xla::swig::LocalComputationBuilder::Sort;
    // Namespace-scope Delete* entries for the three wrapped object types.
    945 %unignore xla::swig::DeleteLocalShapedBuffer;
    946 %unignore xla::swig::DeleteLocalComputation;
    947 %unignore xla::swig::DeleteCompiledLocalComputation;
    // Pull in the declarations to wrap. The %include is bracketed by
    // %thread / %nothread so the thread setting covers only what this header
    // declares and is switched off again afterwards.
    // NOTE(review): presumably this makes the generated wrappers release the
    // Python GIL around these calls -- confirm against the SWIG Python docs.
    949 %thread;
    950 %include "tensorflow/compiler/xla/python/local_computation_builder.h"
    951 %nothread;
    // Counterpart of the %ignoreall above; presumably restores default
    // visibility for anything SWIG processes after this point -- confirm.
    953 %unignoreall