/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

%include "tensorflow/python/platform/base.i"

%{

#include "tensorflow/c/python_api.h"
#include "tensorflow/python/client/tf_session_helper.h"
#include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"

// Helper function to convert a Python list of Tensors to a C++ vector of
// TF_Outputs.
//
// Args:
//   py_tensor_list: must be a Python list whose items are SWIG-wrapped
//     TF_Output objects.
//   vec: receives a copy of every TF_Output, appended in list order.
//   error_msg: set to a human-readable description on failure.
//
// Returns true if successful. Otherwise, returns false and sets error_msg.
// Elements appended to `vec` before the failing item are left in place.
bool PyTensorListToVector(PyObject* py_tensor_list,
                          std::vector<TF_Output>* vec,
                          string* error_msg) {
  if (!PyList_Check(py_tensor_list)) {
    *error_msg = "expected Python list.";
    return false;
  }
  // PyList_Size() returns Py_ssize_t; keep the index in the same type so the
  // loop condition does not mix signed and unsigned operands.
  const Py_ssize_t size = PyList_Size(py_tensor_list);
  for (Py_ssize_t i = 0; i < size; ++i) {
    PyObject* item = PyList_GetItem(py_tensor_list, i);
    TF_Output* input_ptr = nullptr;
    // SWIG_ConvertPtr can report success while producing a null pointer (for
    // example when the item is None), so a null result is rejected as well
    // instead of being dereferenced below.
    if (!SWIG_IsOK(SWIG_ConvertPtr(item, reinterpret_cast<void**>(&input_ptr),
                                   SWIGTYPE_p_TF_Output, 0)) ||
        input_ptr == nullptr) {
      *error_msg = "expected Python list of wrapped TF_Output objects. "
          "Found python list of something else.";
      return false;
    }
    vec->push_back(*input_ptr);
  }
  return true;
}

// Helper function to convert a TF_Output to a wrapped TF_Output Python object.
PyObject* CreateWrappedTFOutput(TF_Output tf_output) {
  // We used heap-allocated pointers in the Python runtime (this is what SWIG
  // generates by default for functions returning TF_Output).
  TF_Output* tf_output_ptr = new TF_Output(tf_output);
  // Use SWIG_POINTER_OWN so the TF_Output* is deleted by Python.
  return SWIG_NewPointerObj(tf_output_ptr, SWIGTYPE_p_TF_Output,
                            SWIG_POINTER_OWN);
}

// Helper function to convert a TF_Operation to a wrapped TF_Operation Python
// object.
PyObject* CreateWrappedTFOperation(TF_Operation* tf_operation) {
  // No flags since operation is owned by TF_Graph.
  return SWIG_NewPointerObj(tf_operation, SWIGTYPE_p_TF_Operation, 0);
}

// Helper function to convert a Python list of ints to a C++ vector of int64s.
//
// `py_int_seq` must be the result of a successful PySequence_Fast() call (the
// PySequence_Fast_* accessors used here do no type checking). Each item is
// converted with PyLong_AsLongLong() so that the full 64-bit range survives
// even where `long` is only 32 bits wide.
//
// If an item cannot be converted, -1 is appended for it and the Python error
// indicator is left set; callers that care should test PyErr_Occurred() after
// this function returns.
void PyInt64ListToVector(PyObject* py_int_seq, std::vector<int64_t>* vec) {
  const Py_ssize_t size = PySequence_Fast_GET_SIZE(py_int_seq);
  vec->reserve(vec->size() + size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    PyObject* item = PySequence_Fast_GET_ITEM(py_int_seq, i);
    vec->push_back(static_cast<int64_t>(PyLong_AsLongLong(item)));
  }
}

%}

%include "tensorflow/python/client/tf_sessionrun_wrapper.i"

// Required to use PyArray_* functions.
%init %{
tensorflow::ImportNumpy();
%}

// Version string and GraphDef version numbers exported as module constants.
%constant const char* __version__ = TF_VERSION_STRING;
%constant int GRAPH_DEF_VERSION = TF_GRAPH_DEF_VERSION;
%constant int GRAPH_DEF_VERSION_MIN_CONSUMER = TF_GRAPH_DEF_VERSION_MIN_CONSUMER;
%constant int GRAPH_DEF_VERSION_MIN_PRODUCER = TF_GRAPH_DEF_VERSION_MIN_PRODUCER;

// Value of tf_git_version().
%constant const char* __git_version__ = tf_git_version();

// Value of tf_compiler_version().
%constant const char* __compiler_version__ = tf_compiler_version();

// Value of tf_cxx11_abi_flag() (the _GLIBCXX_USE_CXX11_ABI setting).
%constant const int __cxx11_abi_flag__ = tf_cxx11_abi_flag();

// Value of tf_monolithic_build(): whether the build is monolithic.
%constant const int __monolithic_build__ = tf_monolithic_build();

// By default every wrapped call runs with the Python GIL released; wrappers
// that must keep the GIL opt out individually with %noexception.
%exception {
  Py_BEGIN_ALLOW_THREADS;
  $action
  Py_END_ALLOW_THREADS;
}

// TF_SetTarget() receives its target as a null-terminated const char*, taken
// here from a Python bytes object.
%typemap(in) (const char* target) {
  $1 = PyBytes_AsString($input);
  if ($1 == nullptr) {
    // PyBytes_AsString has already set the Python error.
    SWIG_fail;
  }
}

// Constants used by TensorHandle (get_session_handle).
%constant const char* TENSOR_HANDLE_KEY = tensorflow::SessionState::kTensorHandleResourceTypeName;

// Convert TF_OperationName output to unicode python string
%typemap(out) const char* TF_OperationName {
  $result = PyUnicode_FromString($1);
}

// Convert TF_OperationOpType output to unicode python string
%typemap(out) const char* TF_OperationOpType {
  $result = PyUnicode_FromString($1);
}

// Convert TF_DeviceListMemoryBytes and TF_Dim int64_t output to Python integers.
// PyLong_FromLongLong is used rather than PyInt_FromLong because `long` is only
// 32 bits wide on some platforms, which would silently truncate the value.
// (On Python 2 the result is therefore a `long` object.)
%typemap(out) int64_t {
  $result = PyLong_FromLongLong(static_cast<long long>($1));
}

// We use TF_OperationGetControlInputs_wrapper instead of
// TF_OperationGetControlInputs
%ignore TF_OperationGetControlInputs;
%unignore TF_OperationGetControlInputs_wrapper;
// See comment for "%noexception TF_SessionRun_wrapper;"
%noexception TF_OperationGetControlInputs_wrapper;

// Build a Python list of TF_Operation* and return it.
%typemap(out) std::vector<TF_Operation*> tensorflow::TF_OperationGetControlInputs_wrapper {
  $result = PyList_New($1.size());
  if (!$result) {
    SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
  }

  for (size_t i = 0; i < $1.size(); ++i) {
    PyList_SET_ITEM($result, i, CreateWrappedTFOperation($1[i]));
  }
}

// We use TF_OperationOutputConsumers_wrapper instead of
// TF_OperationOutputConsumers
%ignore TF_OperationOutputConsumers;
%unignore TF_OperationOutputConsumers_wrapper;
// See comment for "%noexception TF_SessionRun_wrapper;"
// The directive must name the wrapper exactly as it is unignored above and as
// the out typemap below refers to it, otherwise it applies to nothing.
%noexception TF_OperationOutputConsumers_wrapper;

// Build a Python list of unicode strings and return it. (Operation names are
// always represented as unicode.)
%typemap(out) std::vector<const char*>
    tensorflow::TF_OperationOutputConsumers_wrapper {
  // Bind the returned vector once so it reads the same whether or not SWIG
  // wrapped the result in a SwigValueWrapper.
  const std::vector<const char*>& consumer_names = $1;
  $result = PyList_New(consumer_names.size());
  if ($result == nullptr) {
    SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
  }

  size_t slot = 0;
  for (const char* consumer_name : consumer_names) {
    PyList_SET_ITEM($result, slot, PyUnicode_FromString(consumer_name));
    ++slot;
  }
}

%unignore GetOperationInputs;
// See comment for "%noexception TF_SessionRun_wrapper;"
%noexception GetOperationInputs;

// Return the operation's inputs as a Python list of wrapped TF_Output objects.
// TODO(skyewm): is there some way to generalize this pattern? Maybe a macro?
%typemap(out) std::vector<TF_Output> tensorflow::GetOperationInputs {
  // Unwrap the generated SwigValueWrapper<std::vector<TF_Output>>
  const std::vector<TF_Output>& op_inputs = $1;
  $result = PyList_New(op_inputs.size());
  if ($result == nullptr) {
    SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
  }

  size_t slot = 0;
  for (const TF_Output& op_input : op_inputs) {
    PyList_SET_ITEM($result, slot, CreateWrappedTFOutput(op_input));
    ++slot;
  }
}

%ignore TF_ImportGraphDefResultsMissingUnusedInputMappings;
%unignore TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper;
// See comment for "%noexception TF_SessionRun_wrapper;"
%noexception TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper;

// Return the strings as a Python list of bytes objects, one per element.
%typemap(out) std::vector<string>
    TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{
  const std::vector<string>& mappings = $1;
  $result = PyList_New(mappings.size());
  if ($result == nullptr) {
    SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
  }
  size_t slot = 0;
  for (const string& mapping : mappings) {
    PyList_SET_ITEM($result, slot,
                    PyBytes_FromStringAndSize(mapping.data(), mapping.size()));
    ++slot;
  }
}

////////////////////////////////////////////////////////////////////////////////
// BEGIN TYPEMAPS FOR tensorflow::TF_Run_wrapper()
////////////////////////////////////////////////////////////////////////////////

// Input typemap turning a Python list of bytes objects into a NameVector.
// Shared by several arguments (feed/fetch names, function output names).
%typemap(in) const tensorflow::NameVector& (
    tensorflow::NameVector temp,
    tensorflow::Safe_PyObjectPtr temp_string_list(
        tensorflow::make_safe(static_cast<PyObject*>(nullptr)))) {
  if (!PyList_Check($input)) {
    SWIG_exception_fail(
        SWIG_TypeError,
        tensorflow::strings::Printf(
            "Expected a python list for conversion "
            "to tensorflow::NameVector but got %s",
            Py_TYPE($input)->tp_name).c_str());
  }

  const Py_ssize_t num_names = PyList_Size($input);

  // Private list that pins every name object for the lifetime of the call.
  temp_string_list = tensorflow::make_safe(PyList_New(num_names));
  if (!temp_string_list) {
    SWIG_exception_fail(
        SWIG_MemoryError,
        tensorflow::strings::Printf("Failed to create a list of size %zd",
                                    num_names).c_str());
  }

  for (Py_ssize_t idx = 0; idx < num_names; ++idx) {
    PyObject* name_obj = PyList_GetItem($input, idx);
    if (name_obj == nullptr) {
      SWIG_fail;
    }

    // PyList_GetItem returns a borrowed reference and PyList_SET_ITEM steals
    // one, so take our own reference; this keeps the bytes alive even if the
    // caller's list is modified.
    Py_INCREF(name_obj);
    PyList_SET_ITEM(temp_string_list.get(), idx, name_obj);

    char* name_chars = PyBytes_AsString(name_obj);
    if (name_chars == nullptr) {
      SWIG_exception_fail(
          SWIG_TypeError,
          tensorflow::strings::Printf(
              "Element %zd was of type %s instead of a string",
              idx, Py_TYPE(name_obj)->tp_name).c_str());
    }

    // TODO(mrry): Avoid copying the fetch name in, if this impacts performance.
    temp.push_back(name_chars);
  }
  $1 = &temp;
}

// out_values takes no Python argument: it points at a typemap-local vector
// that the argout typemap later converts into the result.
%typemap(in, numinputs=0) tensorflow::PyObjectVector* out_values (
    tensorflow::PyObjectVector temp) {
  $1 = &temp;
}
// TODO(iga): move this and the corresponding typemap(argout) to
// tf_sessionrun_wrapper.i once we get rid of this code for DeprecatedSession.
%typemap(in, numinputs=0) char** out_handle (
    char* temp) {
  // Start from a null handle so the argout typemap below never calls strlen()
  // or delete[] on an indeterminate pointer if the wrapped function returns
  // without assigning the handle.
  temp = nullptr;
  $1 = &temp;
}

// Build a Python list of outputs and return it.
%typemap(argout) tensorflow::PyObjectVector* out_values {
  // Hold every output in a safe pointer until it has been handed to the list,
  // so the references are dropped if list creation fails.
  std::vector<tensorflow::Safe_PyObjectPtr> out_values_safe;
  for (size_t i = 0; i < $1->size(); ++i) {
    out_values_safe.emplace_back(tensorflow::make_safe($1->at(i)));
  }

  $result = PyList_New($1->size());
  if (!$result) {
    // size() is a size_t, so the matching conversion is %zu (not %zd).
    SWIG_exception_fail(
        SWIG_MemoryError,
        tensorflow::strings::Printf("Failed to create a list of size %zu",
                                    $1->size()).c_str());
  }

  for (size_t i = 0; i < $1->size(); ++i) {
    // PyList_SET_ITEM steals the reference, so give up our ownership of it.
    PyList_SET_ITEM($result, i, $1->at(i));
    out_values_safe[i].release();
  }
}

// Return the handle as a python string object.
%typemap(argout) char** out_handle {
%#if PY_MAJOR_VERSION < 3
  $result = PyString_FromStringAndSize(
%#else
  $result = PyUnicode_FromStringAndSize(
%#endif
    *$1, *$1 == nullptr ? 0 : strlen(*$1));
  delete[] *$1;
}

////////////////////////////////////////////////////////////////////////////////
// END TYPEMAPS FOR tensorflow::TF_Run_wrapper()
////////////////////////////////////////////////////////////////////////////////

// Typemap for TF_Status* inputs that automatically unwraps a ScopedTFStatus.
// This can also handle a wrapped TF_Status* input.
%typemap(in) (TF_Status*) {
  // PyObject_GetAttrString() returns a new reference. It is held here so that
  // it is released on every exit path of this typemap; the ScopedTFStatus
  // object itself still references its `status` attribute afterwards.
  tensorflow::Safe_PyObjectPtr status_attr(
      tensorflow::make_safe(static_cast<PyObject*>(nullptr)));
  PyObject* wrapped_tf_status;
  if (strcmp(Py_TYPE($input)->tp_name, "ScopedTFStatus") == 0) {
    DCHECK(PyObject_HasAttrString($input, "status"))
        << "ScopedTFStatus.status not found! Do you need to modify "
           "tf_session.i?";
    status_attr =
        tensorflow::make_safe(PyObject_GetAttrString($input, "status"));
    if (!status_attr) {
      // Attribute lookup failed and Python has raised an error. This check is
      // needed because DCHECK does nothing in non-debug builds.
      SWIG_fail;
    }
    wrapped_tf_status = status_attr.get();
  } else {
    // Assume wrapped TF_Status*
    wrapped_tf_status = $input;
  }
  DCHECK_EQ(strcmp(Py_TYPE(wrapped_tf_status)->tp_name, "SwigPyObject"), 0)
      << Py_TYPE(wrapped_tf_status)->tp_name;

  // The following is the default SWIG code generated for TF_Status*, except
  // that the error message names the function and argument actually being
  // converted instead of a hard-coded one.
  void* tf_status = nullptr;
  int r = SWIG_ConvertPtr(wrapped_tf_status, &tf_status,
                          $descriptor(TF_Status*), 0);
  if (!SWIG_IsOK(r)) {
    SWIG_exception_fail(
        SWIG_ArgError(r),
        "in method '$symname', argument $argnum of type 'TF_Status *'");
  }
  $1 = reinterpret_cast<TF_Status*>(tf_status);
}

// Typemap for functions that return a TF_Buffer struct. This typemap creates a
// Python string from the TF_Buffer and returns it. The TF_Buffer.data string
// is not expected to be NULL-terminated, and TF_Buffer.length does not count
// the terminator.
%typemap(out) TF_Buffer (TF_GetOpList,TF_GetBuffer) {
  $result = PyBytes_FromStringAndSize(
      reinterpret_cast<const char*>($1.data), $1.length);
}

// Converts input Python list of wrapped TF_Outputs into a single array
%typemap(in) (const TF_Output* inputs, int num_inputs)
    (std::vector<TF_Output> inputs) {
  string error_msg;
  if (!PyTensorListToVector($input, &inputs, &error_msg)) {
    SWIG_exception_fail(SWIG_TypeError, ("$symname: " + error_msg).c_str());
  }
  $1 = inputs.data();
  $2 = inputs.size();
}

// Typemaps for TF_ImportGraphDefResultsReturnOutputs
%typemap(in, numinputs=0) (int* num_outputs, TF_Output** outputs)
    (int num_outputs, TF_Output* outputs) {
  $1 = &num_outputs;
  $2 = &outputs;
}

// Returns the outputs as a Python list of wrapped TF_Output objects.
%typemap(argout) (int* num_outputs, TF_Output** outputs) {
  $result = PyList_New(*$1);
  if (!$result) {
    SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
  }
  int num_outputs = *$1;
  TF_Output* outputs = *$2;
  for (int i = 0; i < num_outputs; ++i) {
    PyList_SET_ITEM($result, i, CreateWrappedTFOutput(outputs[i]));
  }
}

// Typemaps for TF_ImportGraphDefResultsReturnOperations
%typemap(in, numinputs=0) (int* num_opers, TF_Operation*** opers)
    (int num_opers, TF_Operation** opers) {
  $1 = &num_opers;
  $2 = &opers;
}

// Returns the operations as a Python list of wrapped TF_Operation objects.
%typemap(argout) (int* num_opers, TF_Operation*** opers) {
  $result = PyList_New(*$1);
  if (!$result) {
    SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
  }
  int num_opers = *$1;
  TF_Operation** opers = *$2;
  for (int i = 0; i < num_opers; ++i) {
    PyList_SET_ITEM($result, i, CreateWrappedTFOperation(opers[i]));
  }
}

// Typemaps for TF_GraphNextOperation().
%typemap(in) size_t* pos (size_t pos) {
  // PyLong_AsUnsignedLong() signals failure (wrong type, negative or too large
  // a value) by returning (unsigned long)-1 with the Python error set. Fail
  // the call instead of passing that sentinel on as a position.
  const unsigned long py_pos = PyLong_AsUnsignedLong($input);
  if (py_pos == static_cast<unsigned long>(-1) && PyErr_Occurred()) {
    SWIG_fail;
  }
  pos = py_pos;
  $1 = &pos;
}

// Returns a (TF_Operation*, int pos) tuple.
%typemap(argout) size_t* pos {
  PyObject* new_result = PyTuple_New(2);
  if (!new_result) {
    SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create tuple");
  }
  // Steals $result reference
  PyTuple_SET_ITEM(new_result, 0, $result);
  PyTuple_SET_ITEM(new_result, 1, PyLong_FromSize_t(*$1));
  $result = new_result;
}

// TODO(skyewm): SWIG emits a warning for the const char* in TF_WhileParams,
// skip for now
%ignore TF_WhileParams;
%ignore TF_NewWhile;
%ignore TF_FinishWhile;
%ignore TF_AbortWhile;

// These are defined below, avoid duplicate definitions
%ignore TF_Run;
%ignore TF_PRun;
%ignore TF_PRunSetup;

// We use TF_SessionRun_wrapper instead of TF_SessionRun
%ignore TF_SessionRun;
%unignore TF_SessionRun_wrapper;
// The %exception block above releases the Python GIL for the length of each
// wrapped method.
// We disable this behavior for TF_SessionRun_wrapper because it
// uses Python method(s) that expect the GIL to be held (at least
// PyArray_Return, maybe others).
%noexception TF_SessionRun_wrapper;

// We use TF_SessionPRunSetup_wrapper instead of TF_SessionPRunSetup
%ignore TF_SessionPRunSetup;
%unignore TF_SessionPRunSetup_wrapper;
// See comment for "%noexception TF_SessionRun_wrapper;"
%noexception TF_SessionPRunSetup_wrapper;

// We use TF_SessionPRun_wrapper instead of TF_SessionPRun
%ignore TF_SessionPRun;
%unignore TF_SessionPRun_wrapper;
// See comment for "%noexception TF_SessionRun_wrapper;"
%noexception TF_SessionPRun_wrapper;

// The raw option setters are exposed under private names; the public
// TF_NewSessionOptions is the Python function inserted below.
%rename("_TF_SetTarget") TF_SetTarget;
%rename("_TF_SetConfig") TF_SetConfig;
%rename("_TF_NewSessionOptions") TF_NewSessionOptions;

%include "tensorflow/c/c_api.h"
%include "tensorflow/c/python_api.h"


%ignoreall
%insert("python") %{
def TF_NewSessionOptions(target=None, config=None):
  """Creates a session options object, optionally setting target and config.

  Args:
    target: if not None, passed to _TF_SetTarget.
    config: if not None, an object whose SerializeToString() result is passed
      to _TF_SetConfig.

  Returns:
    The options object created by _TF_NewSessionOptions. The caller is
    responsible for releasing it with TF_DeleteSessionOptions.

  Raises:
    Whatever _TF_SetTarget, SerializeToString or the status check raises. In
    that case the partially configured options object is deleted first.
  """
  # NOTE: target and config are validated in the session constructor.
  opts = _TF_NewSessionOptions()
  try:
    if target is not None:
      _TF_SetTarget(opts, target)
    if config is not None:
      from tensorflow.python.framework import errors
      with errors.raise_exception_on_not_ok_status() as status:
        config_str = config.SerializeToString()
        _TF_SetConfig(opts, config_str, status)
  except:
    # The caller never receives `opts` on failure, so free it here rather than
    # leaking the native object, then let the original error propagate.
    TF_DeleteSessionOptions(opts)
    raise
  return opts
%}

// Include the wrapper for TF_Run from tf_session_helper.h.

// The %exception block above releases the Python GIL for the length
// of each wrapped method. We disable this behavior for TF_Run
// because it uses the Python allocator.
%noexception tensorflow::TF_Run_wrapper;
%rename(TF_Run) tensorflow::TF_Run_wrapper;
%unignore tensorflow;
%unignore TF_Run;
%unignore EqualGraphDefWrapper;
%unignore EqualAttrValueWrapper;

// Expose tensorflow::TF_PRunSetup_wrapper (tf_session_helper.h) as
// TF_PRunSetup.

// As with TF_Run, the GIL-releasing %exception block is switched off for
// TF_PRunSetup because it uses the Python allocator.
%noexception tensorflow::TF_PRunSetup_wrapper;
%rename(TF_PRunSetup) tensorflow::TF_PRunSetup_wrapper;
%unignore tensorflow;
%unignore TF_PRunSetup;

// Expose tensorflow::TF_PRun_wrapper (tf_session_helper.h) as TF_PRun.

// As with TF_Run, the GIL-releasing %exception block is switched off for
// TF_PRun because it uses the Python allocator.
%noexception tensorflow::TF_PRun_wrapper;
%rename(TF_PRun) tensorflow::TF_PRun_wrapper;
%unignore tensorflow;
%unignore TF_PRun;

%unignore tensorflow::TF_Reset_wrapper;
%insert("python") %{
def TF_Reset(target, containers=None, config=None):
  """Calls TF_Reset_wrapper with options built from `target` and `config`.

  The temporary session options are deleted whether or not the call succeeds.
  A non-OK status is raised as an exception by
  errors.raise_exception_on_not_ok_status.
  """
  from tensorflow.python.framework import errors
  session_options = TF_NewSessionOptions(target=target, config=config)
  try:
    with errors.raise_exception_on_not_ok_status() as status:
      TF_Reset_wrapper(session_options, containers, status)
  finally:
    TF_DeleteSessionOptions(session_options)
%}

// TF_GraphToFunction is hidden; TF_GraphToFunction_wrapper is used instead.
%ignore TF_GraphToFunction;
// TF_GraphToFunction_wrapper does not use any Python methods and
// does not require GIL to be held.
%unignore TF_GraphToFunction_wrapper;

// $input is a Python list of wrapped TF_Operations, or None.
//
// None is passed through as a null vector pointer. Otherwise every list item
// must convert to a non-null TF_Operation*; anything else raises TypeError
// instead of pushing an uninitialized or null pointer into the vector.
%typemap(in) (const std::vector<TF_Operation*>* opers)
    (std::vector<TF_Operation*> opers) {
  if ($input != Py_None) {
    if (!PyList_Check($input)) {
      SWIG_exception_fail(SWIG_TypeError, "$symname: expected list");
    }
    // PyList_Size() returns Py_ssize_t; index with the same type.
    const Py_ssize_t size = PyList_Size($input);
    opers.reserve(size);
    for (Py_ssize_t i = 0; i < size; ++i) {
      PyObject* item = PyList_GetItem($input, i);
      TF_Operation* oper_ptr = nullptr;
      const int r = SWIG_ConvertPtr(item, reinterpret_cast<void**>(&oper_ptr),
                                    $descriptor(TF_Operation*), 0);
      // SWIG_ConvertPtr can succeed with a null pointer (e.g. for None), so
      // check both the status and the resulting pointer.
      if (!SWIG_IsOK(r) || oper_ptr == nullptr) {
        SWIG_exception_fail(
            SWIG_TypeError,
            "$symname: expected list of wrapped TF_Operation objects");
      }
      opers.push_back(oper_ptr);
    }
    $1 = &opers;
  } else {
    $1 = nullptr;
  }
}

// Typemaps for TF_GraphGetTensorShapeHelper.

// Convert from C++ integer vector to Python list of ints.
// PyLong_FromLongLong keeps the full int64 value even where `long` is 32 bits.
%typemap(out) tensorflow::gtl::InlinedVector<int64_t, 6>
    tensorflow::TF_GraphGetTensorShapeHelper {
  $result = PyList_New($1.size());
  if (!$result) {
    SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
  }

  for (size_t i = 0; i < $1.size(); ++i) {
    PyList_SET_ITEM($result, i,
                    PyLong_FromLongLong(static_cast<long long>($1[i])));
  }
}

// unknown_shape is an output-only argument backed by a typemap-local bool.
%typemap(in, numinputs=0) bool* unknown_shape (bool temp) {
  $1=&temp;
}

// Returns a (list(int), bool) tuple.
%typemap(argout) bool* unknown_shape {
  PyObject* new_result = PyTuple_New(2);
  if (!new_result) {
    SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create tuple");
  }
  // Steals $result reference
  PyTuple_SET_ITEM(new_result, 0, $result);
  PyTuple_SET_ITEM(new_result, 1, PyBool_FromLong(*$1));
  $result = new_result;
}

%unignore tensorflow;
%unignore TF_GraphGetTensorShapeHelper;
%ignore TF_GraphGetTensorShape;

// We use TF_GraphSetTensorShape_wrapper instead of
// TF_GraphSetTensorShape
%ignore TF_GraphSetTensorShape;
%unignore tensorflow;
%unignore TF_GraphSetTensorShape_wrapper;

// Converts $input, a Python sequence of ints (or None), to a
// std::vector<int64_t> for TF_GraphSetTensorShape_wrapper.
//
// NOTE(review): None yields a null pointer for a reference parameter; this
// assumes the wrapped function never touches `dims` in that case -- confirm
// against TF_GraphSetTensorShape_wrapper.
%typemap(in) (const std::vector<int64_t>& dims)
    (std::vector<int64_t> dims_local){
  if ($input != Py_None) {
    PyObject* py_int_seq = PySequence_Fast($input, tensorflow::strings::Printf(
        "$symname: expected list but got %s ",
        Py_TYPE($input)->tp_name).c_str());
    if (py_int_seq == nullptr) {
      SWIG_exception_fail(SWIG_RuntimeError, tensorflow::strings::Printf(
          "$symname: PySequence_Fast returned NULL.").c_str());
    }
    PyInt64ListToVector(py_int_seq, &dims_local);
    Py_DECREF(py_int_seq);
    // The helper reports nothing itself; a failed element conversion leaves
    // the Python error indicator set. Propagate it rather than calling the
    // wrapped function with a placeholder dimension and a pending exception.
    if (PyErr_Occurred()) {
      SWIG_fail;
    }
    $1 = &dims_local;
  } else {
    $1 = nullptr;
  }
}

// We use TF_GraphGetTensorShape_wrapper instead of
// TF_GraphGetTensorShape
%ignore TF_GraphGetTensorShape;
%unignore tensorflow;
%unignore TF_GraphGetTensorShape_wrapper;

// Build a Python list of ints and return it.
// PyLong_FromLongLong keeps the full int64 value even where `long` is 32 bits.
%typemap(out) std::vector<int64_t> tensorflow::TF_GraphGetTensorShape_wrapper {
  $result = PyList_New($1.size());
  if (!$result) {
    SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
  }

  for (size_t i = 0; i < $1.size(); ++i) {
    PyList_SET_ITEM($result, i,
                    PyLong_FromLongLong(static_cast<long long>($1[i])));
  }
}

// We use TF_GraphSetOutputHandleShapesAndTypes_wrapper instead of
// TF_GraphSetOutputHandleShapesAndTypes
%ignore TF_GraphSetOutputHandleShapesAndTypes;
%unignore tensorflow;
%unignore TF_GraphSetOutputHandleShapesAndTypes_wrapper;

// Converts a non-empty Python sequence of shapes into a vector of int64
// vectors. Each shape is either None (stored as an empty vector) or a
// sequence of ints. The sequence objects are held in safe pointers so they
// are released on every exit path, including the error paths.
//
// The space between the double angle brackets below looks extraneous, but
// our version of SWIG cannot parse ">>".
%typemap(in) (const std::vector<std::vector<int64_t> >& shapes)
    (std::vector<std::vector<int64_t> > shapes_local){
  tensorflow::Safe_PyObjectPtr seq = tensorflow::make_safe(
      PySequence_Fast($input, tensorflow::strings::Printf(
          "$symname: expected list but got %s ",
          Py_TYPE($input)->tp_name).c_str()));
  if (!seq) {
    SWIG_exception_fail(SWIG_RuntimeError, tensorflow::strings::Printf(
        "$symname: PySequence_Fast returned NULL.").c_str());
  }

  const Py_ssize_t size = PySequence_Fast_GET_SIZE(seq.get());
  if (size == 0) {
    SWIG_exception_fail(SWIG_ValueError, tensorflow::strings::Printf(
        "$symname: shapes list must be non-empty").c_str());
  }

  for (Py_ssize_t i = 0; i < size; ++i) {
    PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i);
    std::vector<int64_t> dims;
    if (item != Py_None) {
      // Report the type of the offending element, not of the outer sequence.
      tensorflow::Safe_PyObjectPtr py_int_seq = tensorflow::make_safe(
          PySequence_Fast(item, tensorflow::strings::Printf(
              "$symname: expected list but got %s ",
              Py_TYPE(item)->tp_name).c_str()));
      if (!py_int_seq) {
        SWIG_exception_fail(SWIG_RuntimeError, tensorflow::strings::Printf(
            "$symname: PySequence_Fast returned NULL.").c_str());
      }
      PyInt64ListToVector(py_int_seq.get(), &dims);
      // A failed element conversion leaves the Python error indicator set.
      if (PyErr_Occurred()) {
        SWIG_fail;
      }
    }
    shapes_local.push_back(dims);
  }

  $1 = &shapes_local;
}

// Converts a non-empty Python sequence of ints into a vector of ranks.
%typemap(in) (const std::vector<int>& ranks)
    (std::vector<int> ranks_local){
  tensorflow::Safe_PyObjectPtr seq = tensorflow::make_safe(
      PySequence_Fast($input, tensorflow::strings::Printf(
          "$symname: expected list but got %s ",
          Py_TYPE($input)->tp_name).c_str()));
  if (!seq) {
    SWIG_exception_fail(SWIG_RuntimeError, tensorflow::strings::Printf(
        "$symname: PySequence_Fast returned NULL.").c_str());
  }

  const Py_ssize_t size = PySequence_Fast_GET_SIZE(seq.get());
  if (size == 0) {
    SWIG_exception_fail(SWIG_ValueError, tensorflow::strings::Printf(
        "$symname: ranks list must be non-empty").c_str());
  }

  for (Py_ssize_t i = 0; i < size; ++i) {
    PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i);
    const long rank = PyInt_AsLong(item);
    // -1 is only an error when the Python error indicator is also set.
    if (rank == -1 && PyErr_Occurred()) {
      SWIG_fail;
    }
    ranks_local.push_back(static_cast<int>(rank));
  }

  $1 = &ranks_local;
}

// Converts a non-empty Python sequence of ints into a vector of TF_DataType.
%typemap(in) (const std::vector<TF_DataType>& types)
    (std::vector<TF_DataType> types_local){
  tensorflow::Safe_PyObjectPtr seq = tensorflow::make_safe(
      PySequence_Fast($input, tensorflow::strings::Printf(
          "$symname: expected list but got %s ",
          Py_TYPE($input)->tp_name).c_str()));
  if (!seq) {
    SWIG_exception_fail(SWIG_RuntimeError, tensorflow::strings::Printf(
        "$symname: PySequence_Fast returned NULL.").c_str());
  }

  const Py_ssize_t size = PySequence_Fast_GET_SIZE(seq.get());
  if (size == 0) {
    SWIG_exception_fail(SWIG_ValueError, tensorflow::strings::Printf(
        "$symname: types list must be non-empty").c_str());
  }

  for (Py_ssize_t i = 0; i < size; ++i) {
    PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i);
    const long type_value = PyInt_AsLong(item);
    // -1 is only an error when the Python error indicator is also set.
    if (type_value == -1 && PyErr_Occurred()) {
      SWIG_fail;
    }
    types_local.push_back(static_cast<TF_DataType>(type_value));
  }

  $1 = &types_local;
}

%include "tensorflow/python/client/tf_session_helper.h"

%unignoreall