1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/python/util/util.h" 16 17 #include "tensorflow/core/lib/strings/strcat.h" 18 #include "tensorflow/core/platform/logging.h" 19 20 namespace tensorflow { 21 namespace swig { 22 23 namespace { 24 25 // Type object for collections.Sequence. This is set by RegisterSequenceClass. 26 PyObject* CollectionsSequenceType = nullptr; 27 28 bool WarnedThatSetIsNotSequence = false; 29 30 // Returns 1 if `o` is considered a sequence for the purposes of Flatten(). 31 // Returns 0 otherwise. 32 // Returns -1 if an error occurred. 33 int IsSequenceHelper(PyObject* o) { 34 if (PyDict_Check(o)) return true; 35 if (PySet_Check(o) && !WarnedThatSetIsNotSequence) { 36 LOG(WARNING) << "Sets are not currently considered sequences, " 37 "but this may change in the future, " 38 "so consider avoiding using them."; 39 WarnedThatSetIsNotSequence = true; 40 } 41 if (CollectionsSequenceType == nullptr) { 42 PyErr_SetString( 43 PyExc_RuntimeError, 44 tensorflow::strings::StrCat( 45 "collections.Sequence type has not been set. " 46 "Please call RegisterSequenceClass before using this module") 47 .c_str()); 48 return -1; 49 } 50 int is_instance = PyObject_IsInstance(o, CollectionsSequenceType); 51 if (is_instance == -1) return -1; 52 return static_cast<int>(is_instance != 0 && !PyBytes_Check(o) && 53 #if PY_MAJOR_VERSION < 3 54 !PyString_Check(o) && 55 #endif 56 !PyUnicode_Check(o)); 57 } 58 59 bool FlattenHelper(PyObject* nested, PyObject* list) { 60 // if nested is not a sequence, append itself and exit 61 int is_seq = IsSequenceHelper(nested); 62 if (is_seq == -1) return false; 63 if (!is_seq) { 64 return PyList_Append(list, nested) != -1; 65 } 66 67 // if nested if dictionary, sort it by key and recurse on each value 68 if (PyDict_Check(nested)) { 69 PyObject* keys = PyDict_Keys(nested); 70 if (PyList_Sort(keys) == -1) return false; 71 Py_ssize_t size = PyList_Size(keys); 72 for (Py_ssize_t i = 0; i < size; ++i) { 73 // We know that key and val will not be deleted because nested owns 74 // a reference to them and callers of flatten must not modify nested 75 // while the method is running. 76 PyObject* key = PyList_GET_ITEM(keys, i); 77 PyObject* val = PyDict_GetItem(nested, key); 78 if (Py_EnterRecursiveCall(" in Flatten")) { 79 Py_DECREF(keys); 80 return false; 81 } 82 FlattenHelper(val, list); 83 Py_LeaveRecursiveCall(); 84 } 85 Py_DECREF(keys); 86 return true; 87 } 88 89 // iterate and recurse 90 PyObject* item; 91 PyObject* iterator = PyObject_GetIter(nested); 92 while ((item = PyIter_Next(iterator)) != nullptr) { 93 FlattenHelper(item, list); 94 Py_DECREF(item); 95 } 96 Py_DECREF(iterator); 97 return true; 98 } 99 100 } // anonymous namespace 101 102 void RegisterSequenceClass(PyObject* sequence_class) { 103 if (!PyType_Check(sequence_class)) { 104 PyErr_SetString( 105 PyExc_TypeError, 106 tensorflow::strings::StrCat( 107 "Expecting a class definition for `collections.Sequence`. Got ", 108 Py_TYPE(sequence_class)->tp_name) 109 .c_str()); 110 return; 111 } 112 CollectionsSequenceType = sequence_class; 113 } 114 115 bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; } 116 117 PyObject* Flatten(PyObject* nested) { 118 PyObject* list = PyList_New(0); 119 if (FlattenHelper(nested, list)) { 120 return list; 121 } else { 122 Py_DECREF(list); 123 return nullptr; 124 } 125 } 126 } // namespace swig 127 } // namespace tensorflow 128