/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// SWIG typemaps and declarations for building, compiling, and
// executing XLA computations, wrapping most of what is declared in
// local_computation_builder.h.
//
// The typemaps below implement/assert the following correspondences
// (with elaborations below):
//
//    C++                                  Python
// -------------------------------------+---------------------------------------
//  ComputationDataHandle              <-> int
//  ArraySlice<int64>                  <-  sequence of int
//  ArraySlice<ComputationDataHandle>  <-  sequence of int
//  ArraySlice<LocalShapedBuffer*>     <-  sequence of LocalShapedBuffer objects
//  Literal                            <-> (nested tuple of) numpy ndarray
//  std::vector<Literal>               <-  sequence of (nested tuple of) ndarray
//  Shape                               -> pair holding (dtype, dimensions)
//                                     <-  object duck-typed as xla_client.Shape
//  optional<Shape>                    <-  xla_client.Shape object or None
//  std::vector<Shape>                 <-  sequence of xla_client.Shape objects
//  std::vector<optional<Shape>>       <-  sequence of xla_client.Shape or None
//  PrimitiveType                      <-  int
//  ArraySlice<pair<int64, int64>>     <-  sequence of int pairs
//  PaddingConfig proto                <-  corresponding Python proto
//  ConvolutionDimensionNumbers proto  <-  corresponding Python proto
//  DotDimensionNumbers proto          <-  corresponding Python proto
//  OpMetadata                         <-  object accepted by
//                                         numpy::OpMetadataFromPyObject
//  ExecutableBuildOptions*            <-  object with generate_hlo_graph and
//                                         result_shape attributes, or None
//  Status                              -> None, or raises RuntimeError
//  StatusOr<T>                         -> converted T, or raises RuntimeError
//
// Arrows indicate whether a conversion only ever occurs in one
// direction, or whether it is maintained bidirectionally.
//
// The Python objects corresponding to C++ Literals have the type:
//
//   T = ndarray | (T, ...)
//
// where a terminal numpy ndarray translates to a Literal with a
// non-tuple Shape and an XLA primitive element type corresponding to
// the ndarray's dtype. Meanwhile, a non-terminal "tuple of T" translates
// to a tuple-shaped Literal whose tuple components are translated
// recursively. For example, if x is a numpy ndarray in Python, with
// shape (2, 3) and dtype of dtype('float32'), then x translates to a
// Literal with rank 2, dimensions 2 and 3, and XLA primitive type
// F32. Meanwhile,
//
//   (x, (x, x), (x,))
//
// translates to a tuple-shaped XLA Literal, whose component subshapes
// are a 2x3 F32-shaped literal followed by two tuple-shaped literals.
//
// Shapes output by C++ become Python objects with the type:
//
//   T            = (dtype, S)
//   S            = DIMENSIONS | TUPLE_SHAPES
//   DIMENSIONS   = (int, ...)
//   TUPLE_SHAPES = (T, ...)
//
// In the pair described by the T rule, the terminal dtype determines
// whether S expands as DIMENSIONS or TUPLE_SHAPES. Namely, if it is
// dtype('O'), numpy's object dtype, the structure represents a tuple
// shape and the expansion of the non-terminal S is
// TUPLE_SHAPES. Otherwise, dtype describes a primitive element type
// and S expands into DIMENSIONS giving dimension sizes. For example:
//
//   (dtype('float32'), (3, 5, 7))
//
// describes a 3x5x7 array of F32s, and
//
//   (dtype('O'), ((dtype('float32'), (2, 3)),
//                 (dtype('float64'), (4, 5))))
//
// describes a tuple shape with two subshapes: the first a 2x3 F32,
// and the other a 4x5 F64.
//
// The Python int corresponding to a PrimitiveType enum must be valid
// per xla_data.proto (e.g. xla_data.PRED, xla_data.F32).
86 // 87 // The SWIG object wrappers generated by this file are not intended 88 // for end use, but rather for internal use in the Python XLA client, 89 // xla_client.py. 90 // 91 // One central reason for the Python-side indirection is that the 92 // Python-side objects produced by the typemaps in this file are 93 // further packaged up by xla_client before being passed on. For 94 // instance, xla_client wraps the long produced for a C++ 95 // ComputationDataHandle in a Python ComputationDataHandle proto, 96 // rather than exposing a raw long outside of the client. Similarly, 97 // the Python pair produced for a C++ Shape is further wrapped in a 98 // Python class (xla_client.Shape) so as not to expose the raw pair 99 // externally. 100 // 101 // Other SWIG object wrappers (e.g. of LocalComputation) are further 102 // wrapped by xla_client in order to set up a custom destructor that 103 // triggers memory deallocation on the C++ side. 104 105 %module(threads="1") local_computation_builder 106 107 // Keep the GIL except where explicitly specified. 
%nothread;

%include "tensorflow/python/platform/base.i"

%{
// Must be included first
#include "tensorflow/python/lib/core/numpy.h"

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
// tensorflow::strings::StrCat is used by the ExecutableBuildOptions typemap.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/compiler/xla/python/numpy_bridge.h"
#include "tensorflow/compiler/xla/python/local_computation_builder.h"

using namespace xla;
using namespace xla::swig;

namespace xla {
namespace swig {

// Reads attribute `field` of the Python object `o` as an integer.
//
// On success stores the value in `*result` and returns true. On failure
// (missing attribute, or a value that numpy::PyIntOrPyLongToLong cannot
// convert) returns false with a Python exception set, and leaves `*result`
// untouched.
bool GetIntAttr(PyObject* o, const char* field, int64* result) {
  PyObject* fo = PyObject_GetAttrString(o, field);
  if (!fo) {
    return false;
  }
  const int64 value = numpy::PyIntOrPyLongToLong(fo);
  // -1 is also a legal value, so only PyErr_Occurred() distinguishes a
  // conversion failure. Query it before releasing our reference to `fo`.
  const bool conversion_failed = value == -1 && PyErr_Occurred();
  Py_DECREF(fo);
  if (conversion_failed) {
    return false;
  }
  *result = value;
  return true;
}

}  // namespace swig
}  // namespace xla
%}

// Required to use PyArray_* functions.
%init %{
tensorflow::ImportNumpy();
%}

// Conventions shared by the typemaps below:
//  * On failure a Python exception is set and the wrapper returns NULL.
//  * Every new reference obtained from the Python C API
//    (PySequence_GetItem, PyObject_GetAttrString, numpy::PyNumberToPyInt)
//    is released on every path, including the early-return error paths.
//  * Sequence lengths are Py_ssize_t. A -1 from PySequence_Size means an
//    exception is already set.

// ComputationDataHandle: a Python int holding the handle value.

%typemap(in) const ComputationDataHandle& (ComputationDataHandle temp) {
  const int64 handle = numpy::PyIntOrPyLongToLong($input);
  if (handle == -1 && PyErr_Occurred()) {
    return NULL;
  }
  temp.set_handle(handle);
  $1 = &temp;
}

%typemap(out) ComputationDataHandle {
  $result = numpy::LongToPyIntOrPyLong($1.handle());
}

// StatusOr results: unwrap the value, or raise RuntimeError with the status.

%typemap(out) StatusOr<xla::swig::CompiledLocalComputation*> {
  if ($1.ok()) {
    auto* value = $1.ValueOrDie();
    {
      auto* $1 = value;
      $typemap(out, xla::swig::CompiledLocalComputation*)
    }
  } else {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    return NULL;
  }
}

%typemap(out) StatusOr<xla::swig::LocalComputation*> {
  if ($1.ok()) {
    auto* value = $1.ValueOrDie();
    {
      auto* $1 = value;
      $typemap(out, xla::swig::LocalComputation*)
    }
  } else {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    return NULL;
  }
}

%typemap(out) StatusOr<Shape> {
  if ($1.ok()) {
    $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie());
  } else {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    return NULL;
  }
}

// Status: None on success, RuntimeError otherwise.

%typemap(out) Status {
  if (!$1.ok()) {
    PyErr_SetString(
        PyExc_RuntimeError, $1.ToString().c_str());
    return NULL;
  }
  // The wrapper hands a new reference back to Python, so None must be
  // incref'd; returning it borrowed would underflow its refcount.
  Py_INCREF(Py_None);
  $result = Py_None;
}

// ArraySlice<int64>: any sequence whose elements convert to int.

%typemap(in) tensorflow::gtl::ArraySlice<int64>
    (std::vector<int64> temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    return NULL;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size == -1) {
    return NULL;
  }
  temps.resize(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    PyObject* o = PySequence_GetItem($input, i);
    if (!o) {
      return NULL;
    }
    PyObject* py_int = numpy::PyNumberToPyInt(o);
    Py_DECREF(o);
    if (!py_int) {
      PyErr_SetString(
          PyExc_TypeError,
          "Argument sequence element cannot be converted to int");
      return NULL;
    }
    temps[i] = numpy::PyIntOrPyLongToLong(py_int);
    const bool conversion_failed = temps[i] == -1 && PyErr_Occurred();
    Py_DECREF(py_int);
    if (conversion_failed) {
      return NULL;
    }
  }
  $1 = temps;
}

// ArraySlice<ComputationDataHandle>: sequence of ints, one per handle.

%typemap(in) tensorflow::gtl::ArraySlice<ComputationDataHandle>
    (std::vector<ComputationDataHandle> temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    return NULL;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size == -1) {
    return NULL;
  }
  temps.resize(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    PyObject* o = PySequence_GetItem($input, i);
    if (!o) {
      return NULL;
    }
    PyObject* py_int = numpy::PyNumberToPyInt(o);
    Py_DECREF(o);
    if (!py_int) {
      PyErr_SetString(
          PyExc_TypeError,
          "Argument sequence element cannot be converted to int");
      return NULL;
    }
    const int64 handle = numpy::PyIntOrPyLongToLong(py_int);
    const bool conversion_failed = handle == -1 && PyErr_Occurred();
    Py_DECREF(py_int);
    if (conversion_failed) {
      return NULL;
    }
    temps[i].set_handle(handle);
  }
  $1 = temps;
}

// ArraySlice<LocalShapedBuffer*>: sequence of SWIG-wrapped LocalShapedBuffers.

%typemap(in) tensorflow::gtl::ArraySlice<xla::swig::LocalShapedBuffer*>
    (std::vector<LocalShapedBuffer*> temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    return NULL;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size == -1) {
    return NULL;
  }
  temps.reserve(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    PyObject* o = PySequence_GetItem($input, i);
    if (!o) {
      return NULL;
    }
    LocalShapedBuffer* lsbp;
    const int convert_result = SWIG_ConvertPtr(
        o, reinterpret_cast<void**>(&lsbp),
        $descriptor(xla::swig::LocalShapedBuffer*), SWIG_POINTER_EXCEPTION);
    Py_DECREF(o);
    if (!SWIG_IsOK(convert_result)) {
      // SWIG_POINTER_EXCEPTION is a no-op flag in current SWIG releases, so
      // a failed conversion may not have raised; ensure an exception is set.
      if (!PyErr_Occurred()) {
        PyErr_SetString(
            PyExc_TypeError,
            "Argument sequence element is not a LocalShapedBuffer");
      }
      return NULL;
    }
    temps.push_back(lsbp);
  }
  $1 = temps;
}

// Literal: (nested tuple of) numpy ndarray; see the header comment.

%typemap(in) const Literal& (StatusOr< std::unique_ptr<Literal> > literal_status) {
  literal_status = numpy::XlaLiteralFromPyObject($input);
  if (!literal_status.ok()) {
    PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
    return NULL;
  }
  $1 = literal_status.ValueOrDie().get();
}

%typemap(out) std::unique_ptr<Literal> {
  $result = numpy::PyObjectFromXlaLiteral(*$1);
}

%typemap(out) StatusOr< std::unique_ptr<Literal> > {
  if (!$1.ok()) {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    return NULL;
  }
  $result = numpy::PyObjectFromXlaLiteral(*$1.ValueOrDie());
}

%typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    return NULL;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size == -1) {
    return NULL;
  }
  temps.reserve(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    PyObject* o = PySequence_GetItem($input, i);
    if (!o) {
      return NULL;
    }
    StatusOr< std::unique_ptr<Literal> > literal_status = numpy::XlaLiteralFromPyObject(o);
    if (!literal_status.ok()) {
      PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
      Py_DECREF(o);
      return NULL;
    }
    temps.push_back(std::move(*literal_status.ConsumeValueOrDie()));
    Py_DECREF(o);
  }
  $1 = &temps;
}

// OpMetadata: converted by numpy::OpMetadataFromPyObject.

%typemap(in) const OpMetadata& (OpMetadata temp) {
  StatusOr<OpMetadata> statusor = numpy::OpMetadataFromPyObject($input);
  if (!statusor.ok()) {
    PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
    return NULL;
  }
  temp = std::move(statusor).ValueOrDie();
  $1 = &temp;
}

// Shape: duck-typed xla_client.Shape in; (dtype, dimensions) pair out.

%typemap(in) const Shape& (Shape temp) {
  StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape($input);
  if (!statusor.ok()) {
    PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
    return NULL;
  }
  temp = std::move(statusor).ValueOrDie();
  $1 = &temp;
}

// optional<Shape>: None maps to nullopt, anything else is converted as Shape.

%typemap(in) const tensorflow::gtl::optional<Shape>& (
    tensorflow::gtl::optional<Shape> temp) {
  if ($input == Py_None) {
    temp = tensorflow::gtl::nullopt;
    $1 = &temp;
  } else {
    StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape($input);
    if (!statusor.ok()) {
      PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
      return NULL;
    }
    temp = std::move(statusor).ValueOrDie();
    $1 = &temp;
  }
}

%typemap(out) std::unique_ptr<Shape> {
  $result = numpy::PyShapeInfoFromXlaShape(*$1);
}

%typemap(in) const std::vector<Shape>& (std::vector<Shape> temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    return NULL;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size == -1) {
    return NULL;
  }
  temps.reserve(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    PyObject* o = PySequence_GetItem($input, i);
    if (!o) {
      return NULL;
    }
    StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
    Py_DECREF(o);
    if (!statusor.ok()) {
      PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
      return NULL;
    }
    temps.push_back(statusor.ConsumeValueOrDie());
  }
  $1 = &temps;
}

// std::vector<optional<Shape>>: each element is a Shape or None (nullopt).

%typemap(in) const std::vector<tensorflow::gtl::optional<Shape> >& (
    std::vector<tensorflow::gtl::optional<Shape> > temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    return NULL;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size == -1) {
    return NULL;
  }
  temps.reserve(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    PyObject* o = PySequence_GetItem($input, i);
    if (!o) {
      return NULL;
    }
    if (o == Py_None) {
      // PySequence_GetItem returned a new reference even for None.
      Py_DECREF(o);
      temps.push_back(tensorflow::gtl::nullopt);
      continue;
    }
    StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
    Py_DECREF(o);
    if (!statusor.ok()) {
      PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
      return NULL;
    }
    temps.push_back(statusor.ConsumeValueOrDie());
  }
  $1 = &temps;
}

// PrimitiveType: an int that must be a valid PrimitiveType enum value.

%typemap(in) PrimitiveType {
  PyObject* py_int = numpy::PyNumberToPyInt($input);
  if (!py_int) {
    PyErr_SetString(PyExc_TypeError, "Argument cannot be converted to int");
    return NULL;
  }
  const long value = numpy::PyIntOrPyLongToLong(py_int);
  const bool conversion_failed = value == -1 && PyErr_Occurred();
  // py_int is no longer needed on any path, including success.
  Py_DECREF(py_int);
  if (conversion_failed) {
    return NULL;
  }
  if (!PrimitiveType_IsValid(value)) {
    PyErr_SetString(
        PyExc_TypeError, "Argument not valid for PrimitiveType enum");
    return NULL;
  }
  $1 = static_cast<PrimitiveType>(value);
}

// ArraySlice<pair<int64, int64>>: sequence of int pairs. Each pair may be
// any sequence (tuple, list, ...); only its first two items are read.

%typemap(in) tensorflow::gtl::ArraySlice<std::pair<int64, int64> >
    (std::vector<std::pair<int64, int64> > temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    return NULL;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size == -1) {
    return NULL;
  }
  temps.reserve(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    PyObject* o = PySequence_GetItem($input, i);
    if (!o) {
      return NULL;
    }
    int64 values[2];
    for (Py_ssize_t j = 0; j < 2; ++j) {
      PyObject* item = PySequence_GetItem(o, j);
      if (!item) {
        Py_DECREF(o);
        return NULL;
      }
      PyObject* item_pyint = numpy::PyNumberToPyInt(item);
      Py_DECREF(item);
      if (!item_pyint) {
        PyErr_SetString(
            PyExc_TypeError,
            j == 0 ? "First pair item cannot be converted to int"
                   : "Second pair item cannot be converted to int");
        Py_DECREF(o);
        return NULL;
      }
      values[j] = numpy::PyIntOrPyLongToLong(item_pyint);
      const bool conversion_failed = values[j] == -1 && PyErr_Occurred();
      Py_DECREF(item_pyint);
      if (conversion_failed) {
        Py_DECREF(o);
        return NULL;
      }
    }
    Py_DECREF(o);
    temps.push_back(std::make_pair(values[0], values[1]));
  }
  $1 = temps;
}

%{
namespace xla {
namespace swig {

// Reads attribute `field` of the Python object `o`, which must be a sized
// sequence of ints, and appends each element to `*result`.
//
// Returns false with a Python exception set on failure: a missing attribute,
// a non-sequence, or an element numpy::PyIntOrPyLongToLong cannot convert.
// `*result` may then hold a partial prefix of the elements.
static bool GetIntSequenceAttr(PyObject* o, const char* field,
                               std::vector<int64>* result) {
  PyObject* seq = PyObject_GetAttrString(o, field);
  if (!seq) {
    return false;
  }
  const Py_ssize_t length = PySequence_Size(seq);
  if (length == -1) {
    Py_DECREF(seq);
    return false;
  }
  for (Py_ssize_t i = 0; i < length; ++i) {
    PyObject* item = PySequence_GetItem(seq, i);
    if (!item) {
      Py_DECREF(seq);
      return false;
    }
    const int64 value = numpy::PyIntOrPyLongToLong(item);
    const bool conversion_failed = value == -1 && PyErr_Occurred();
    Py_DECREF(item);
    if (conversion_failed) {
      Py_DECREF(seq);
      return false;
    }
    result->push_back(value);
  }
  Py_DECREF(seq);
  return true;
}

}  // namespace swig
}  // namespace xla
%}

// DotDimensionNumbers: duck-typed from any object with the attributes
// lhs_contracting_dimensions, rhs_contracting_dimensions,
// lhs_batch_dimensions and rhs_batch_dimensions, each a sequence of ints.

%typemap(in) const DotDimensionNumbers&
    (DotDimensionNumbers dimension_numbers) {
  std::vector<int64> dims;

  if (!xla::swig::GetIntSequenceAttr(
          $input, "lhs_contracting_dimensions", &dims)) {
    return NULL;
  }
  for (const int64 dim : dims) {
    dimension_numbers.add_lhs_contracting_dimensions(dim);
  }

  dims.clear();
  if (!xla::swig::GetIntSequenceAttr(
          $input, "rhs_contracting_dimensions", &dims)) {
    return NULL;
  }
  for (const int64 dim : dims) {
    dimension_numbers.add_rhs_contracting_dimensions(dim);
  }

  dims.clear();
  if (!xla::swig::GetIntSequenceAttr(
          $input, "lhs_batch_dimensions", &dims)) {
    return NULL;
  }
  for (const int64 dim : dims) {
    dimension_numbers.add_lhs_batch_dimensions(dim);
  }

  dims.clear();
  if (!xla::swig::GetIntSequenceAttr(
          $input, "rhs_batch_dimensions", &dims)) {
    return NULL;
  }
  for (const int64 dim : dims) {
    dimension_numbers.add_rhs_batch_dimensions(dim);
  }

  $1 = &dimension_numbers;
}

// PaddingConfig: an object whose `dimensions` attribute is a sequence of
// objects with edge_padding_low, edge_padding_high and interior_padding.

%typemap(in) const PaddingConfig&
    (PaddingConfig padding_config) {
  PyObject* dimensions = PyObject_GetAttrString($input, "dimensions");
  if (!dimensions) {
    return NULL;
  }

  const Py_ssize_t length = PySequence_Size(dimensions);
  if (length == -1) {
    Py_DECREF(dimensions);
    return NULL;
  }

  for (Py_ssize_t i = 0; i < length; ++i) {
    PyObject* item = PySequence_GetItem(dimensions, i);
    if (!item) {
      Py_DECREF(dimensions);
      return NULL;
    }
    int64 edge_padding_low, edge_padding_high, interior_padding;
    if (!GetIntAttr(item, "edge_padding_low", &edge_padding_low)
        || !GetIntAttr(item, "edge_padding_high", &edge_padding_high)
        || !GetIntAttr(item, "interior_padding", &interior_padding)) {
      Py_DECREF(item);
      Py_DECREF(dimensions);
      return NULL;
    }
    Py_DECREF(item);

    PaddingConfig::PaddingConfigDimension* dimension =
        padding_config.add_dimensions();
    dimension->set_edge_padding_low(edge_padding_low);
    dimension->set_edge_padding_high(edge_padding_high);
    dimension->set_interior_padding(interior_padding);
  }
  Py_DECREF(dimensions);

  $1 = &padding_config;
}

// ConvolutionDimensionNumbers: an object with int attributes for the batch
// and feature dimensions, plus sequence-of-int spatial dimension attributes.

%typemap(in) const ConvolutionDimensionNumbers&
    (ConvolutionDimensionNumbers dimension_numbers) {
  int64 value;

  if (!GetIntAttr($input, "input_batch_dimension", &value)) {
    return NULL;
  }
  dimension_numbers.set_input_batch_dimension(value);

  if (!GetIntAttr($input, "input_feature_dimension", &value)) {
    return NULL;
  }
  dimension_numbers.set_input_feature_dimension(value);

  if (!GetIntAttr($input, "output_batch_dimension", &value)) {
    return NULL;
  }
  dimension_numbers.set_output_batch_dimension(value);

  if (!GetIntAttr($input, "output_feature_dimension", &value)) {
    return NULL;
  }
  dimension_numbers.set_output_feature_dimension(value);

  if (!GetIntAttr($input, "kernel_output_feature_dimension", &value)) {
    return NULL;
  }
  dimension_numbers.set_kernel_output_feature_dimension(value);

  if (!GetIntAttr($input, "kernel_input_feature_dimension", &value)) {
    return NULL;
  }
  dimension_numbers.set_kernel_input_feature_dimension(value);

  std::vector<int64> dims;

  if (!xla::swig::GetIntSequenceAttr(
          $input, "input_spatial_dimensions", &dims)) {
    return NULL;
  }
  for (const int64 dim : dims) {
    dimension_numbers.add_input_spatial_dimensions(dim);
  }

  dims.clear();
  if (!xla::swig::GetIntSequenceAttr(
          $input, "kernel_spatial_dimensions", &dims)) {
    return NULL;
  }
  for (const int64 dim : dims) {
    dimension_numbers.add_kernel_spatial_dimensions(dim);
  }

  dims.clear();
  if (!xla::swig::GetIntSequenceAttr(
          $input, "output_spatial_dimensions", &dims)) {
    return NULL;
  }
  for (const int64 dim : dims) {
    dimension_numbers.add_output_spatial_dimensions(dim);
  }

  $1 = &dimension_numbers;
}

// ExecutableBuildOptions: None, or an object with generate_hlo_graph
// (string or None) and result_shape (Shape or None) attributes.

%typemap(in) const ExecutableBuildOptions*
    (ExecutableBuildOptions build_options) {
  if ($input == Py_None) {
    $1 = NULL;
  } else {
    PyObject* o = PyObject_GetAttrString($input, "generate_hlo_graph");
    if (!o) {
      return NULL;
    }
    if (o != Py_None) {
      // NOTE(review): on Python 3 builds SWIG presumably maps PyString_Check
      // to PyBytes_Check, which would reject a `str` value here; confirm.
      if (!PyString_Check(o)) {
        PyErr_SetString(PyExc_TypeError, "ExecutableBuildOptions.generate_hlo_graph must be a string or None.");
        Py_DECREF(o);
        return NULL;
      }
      build_options.set_generate_hlo_graph(PyString_AsString(o));
    }
    Py_DECREF(o);

    o = PyObject_GetAttrString($input, "result_shape");
    if (!o) {
      return NULL;
    }
    if (o != Py_None) {
      StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
      if (!statusor.ok()) {
        PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat("ExecutableBuildOptions.result_shape could not be created from Python shape value: ", statusor.status().ToString()).c_str());
        Py_DECREF(o);
        return NULL;
      }
      build_options.set_result_layout(statusor.ValueOrDie());
    }
    Py_DECREF(o);

    $1 = &build_options;
  }
}

%ignoreall
%unignore xla;
%unignore xla::swig;
%unignore xla::swig::InitializeReplicaCount;
%unignore xla::swig::GetReplicaCount;
%unignore xla::swig::TransferToInfeedLocal;
%unignore xla::swig::TransferToInfeedLocalReplica;
%unignore xla::swig::TransferFromOutfeedLocalReplica;
%unignore xla::swig::LocalShapedBuffer;
%unignore xla::swig::LocalShapedBuffer::FromLiteral;
%unignore xla::swig::LocalShapedBuffer::ToLiteral;
%unignore xla::swig::CompiledLocalComputation;
%unignore xla::swig::CompiledLocalComputation::Execute;
%unignore xla::swig::CompiledLocalComputation::ExecuteWithShapedBuffers;
%unignore xla::swig::LocalComputation;
%unignore xla::swig::LocalComputation::Compile;
%unignore xla::swig::LocalComputation::GetReturnValueShape;
%unignore xla::swig::LocalComputationBuilder;
%unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder;
%unignore xla::swig::LocalComputationBuilder::Build;
%unignore xla::swig::LocalComputationBuilder::SetOpMetadata;
%unignore xla::swig::LocalComputationBuilder::ClearOpMetadata;
%unignore xla::swig::LocalComputationBuilder::Parameter;
%unignore xla::swig::LocalComputationBuilder::GetShape;
%unignore xla::swig::LocalComputationBuilder::GetReturnValueShape;
%unignore xla::swig::LocalComputationBuilder::Infeed;
%unignore xla::swig::LocalComputationBuilder::Outfeed;
%unignore xla::swig::LocalComputationBuilder::ConstantLiteral;
%unignore xla::swig::LocalComputationBuilder::ConstantR0;
%unignore xla::swig::LocalComputationBuilder::Broadcast;
%unignore xla::swig::LocalComputationBuilder::Pad;
%unignore xla::swig::LocalComputationBuilder::Reshape;
%unignore xla::swig::LocalComputationBuilder::Collapse;
%unignore xla::swig::LocalComputationBuilder::CrossReplicaSum;
%unignore xla::swig::LocalComputationBuilder::Slice;
%unignore xla::swig::LocalComputationBuilder::DynamicSlice;
%unignore xla::swig::LocalComputationBuilder::DynamicUpdateSlice;
%unignore xla::swig::LocalComputationBuilder::ConcatInDim;
%unignore xla::swig::LocalComputationBuilder::SelectAndScatterWithGeneralPadding;
%unignore xla::swig::LocalComputationBuilder::Select;
%unignore xla::swig::LocalComputationBuilder::Tuple;
%unignore xla::swig::LocalComputationBuilder::GetTupleElement;
%unignore xla::swig::LocalComputationBuilder::ConvertElementType;
%unignore xla::swig::LocalComputationBuilder::Call;
%unignore xla::swig::LocalComputationBuilder::Transpose;
%unignore xla::swig::LocalComputationBuilder::Rev;
%unignore xla::swig::LocalComputationBuilder::Clamp;
%unignore xla::swig::LocalComputationBuilder::Map;
%unignore xla::swig::LocalComputationBuilder::Reduce;
%unignore xla::swig::LocalComputationBuilder::ReduceWindowWithGeneralPadding;
%unignore xla::swig::LocalComputationBuilder::RngNormal;
%unignore xla::swig::LocalComputationBuilder::RngUniform;
%unignore xla::swig::LocalComputationBuilder::RngBernoulli;
%unignore xla::swig::LocalComputationBuilder::While;
%unignore xla::swig::LocalComputationBuilder::Conditional;
%unignore xla::swig::LocalComputationBuilder::Eq;
%unignore xla::swig::LocalComputationBuilder::Ne;
%unignore xla::swig::LocalComputationBuilder::Ge;
%unignore xla::swig::LocalComputationBuilder::Gt;
%unignore xla::swig::LocalComputationBuilder::Lt;
%unignore xla::swig::LocalComputationBuilder::Le;
%unignore xla::swig::LocalComputationBuilder::Dot;
%unignore xla::swig::LocalComputationBuilder::DotGeneral;
%unignore xla::swig::LocalComputationBuilder::ConvGeneralDilated;
%unignore xla::swig::LocalComputationBuilder::Add;
%unignore xla::swig::LocalComputationBuilder::Sub;
%unignore xla::swig::LocalComputationBuilder::Mul;
%unignore xla::swig::LocalComputationBuilder::Div;
%unignore xla::swig::LocalComputationBuilder::Rem;
%unignore xla::swig::LocalComputationBuilder::Max;
%unignore xla::swig::LocalComputationBuilder::Min;
%unignore xla::swig::LocalComputationBuilder::And;
%unignore xla::swig::LocalComputationBuilder::Or;
%unignore xla::swig::LocalComputationBuilder::Not;
%unignore xla::swig::LocalComputationBuilder::Abs;
%unignore xla::swig::LocalComputationBuilder::Exp;
%unignore xla::swig::LocalComputationBuilder::Floor;
%unignore xla::swig::LocalComputationBuilder::Ceil;
%unignore xla::swig::LocalComputationBuilder::Round;
%unignore xla::swig::LocalComputationBuilder::Log;
%unignore xla::swig::LocalComputationBuilder::Sign;
%unignore xla::swig::LocalComputationBuilder::Cos;
%unignore xla::swig::LocalComputationBuilder::Sin;
%unignore xla::swig::LocalComputationBuilder::Tanh;
%unignore xla::swig::LocalComputationBuilder::SqrtF32;
%unignore xla::swig::LocalComputationBuilder::SquareF32;
%unignore xla::swig::LocalComputationBuilder::Pow;
%unignore xla::swig::LocalComputationBuilder::IsFinite;
%unignore xla::swig::LocalComputationBuilder::ReciprocalF32;
%unignore xla::swig::LocalComputationBuilder::Neg;
%unignore xla::swig::LocalComputationBuilder::Sort;
%unignore xla::swig::DeleteLocalShapedBuffer;
%unignore xla::swig::DeleteLocalComputation;
%unignore xla::swig::DeleteCompiledLocalComputation;

%thread;
%include "tensorflow/compiler/xla/python/local_computation_builder.h"
%nothread;

%unignoreall