/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

%include "tensorflow/python/lib/core/strings.i"
%include "tensorflow/python/platform/base.i"

%{
#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
#include "tensorflow/python/lib/core/py_func.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
%}

// Converts a VarToShapeMap into a new Python dict that maps each variable
// name (str) to a list of its dimension sizes (ints).
//
// Reference counting:
//  * PyList_SET_ITEM steals the reference to each dimension object, so that
//    reference is released into the list.
//  * PyDict_SetItem does NOT steal references; it takes its own. `key` and
//    `value` therefore stay owned by their Safe_PyObjectPtr wrappers and are
//    dropped at the end of each iteration. Releasing them would leak one
//    reference per entry.
// On any Python API failure the pending Python error is propagated via
// SWIG_fail; the safe pointers clean up whatever was already created.
%typemap(out) const tensorflow::checkpoint::TensorSliceReader::VarToShapeMap& {
  tensorflow::Safe_PyObjectPtr output_map(tensorflow::make_safe(PyDict_New()));
  if (!output_map) {
    SWIG_fail;
  }
  // Bind by const reference to avoid copying the name and shape of every
  // entry.
  for (const auto& v : *$1) {
%#if PY_MAJOR_VERSION >= 3
    tensorflow::Safe_PyObjectPtr key(
        tensorflow::make_safe(PyUnicode_FromStringAndSize(v.first.c_str(),
                                                          v.first.size())));
%#else
    tensorflow::Safe_PyObjectPtr key(
        tensorflow::make_safe(PyString_FromStringAndSize(v.first.c_str(),
                                                         v.first.size())));
%#endif
    if (!key) {
      SWIG_fail;
    }
    const int dims = v.second.dims();
    tensorflow::Safe_PyObjectPtr value(tensorflow::make_safe(PyList_New(dims)));
    if (!value) {
      SWIG_fail;
    }
    for (int i = 0; i < dims; ++i) {
%#if PY_MAJOR_VERSION >= 3
      // Dimension sizes are assumed to be 64-bit (tensorflow::int64), while
      // `long` is only 32 bits on some platforms (e.g. Windows). Use the
      // long long constructor so large dimensions are not truncated.
      tensorflow::Safe_PyObjectPtr dim_value(
          tensorflow::make_safe(PyLong_FromLongLong(v.second.dim_size(i))));
%#else
      tensorflow::Safe_PyObjectPtr dim_value(
          tensorflow::make_safe(PyInt_FromLong(v.second.dim_size(i))));
%#endif
      if (!dim_value) {
        SWIG_fail;
      }
      // PyList_SET_ITEM steals the reference: hand ownership to the list.
      PyList_SET_ITEM(value.get(), i, dim_value.release());
    }
    // PyDict_SetItem increments the refcounts of key and value itself; our
    // references are dropped when the safe pointers go out of scope.
    if (PyDict_SetItem(output_map.get(), key.get(), value.get()) == -1) {
      SWIG_fail;
    }
  }

  $result = output_map.release();
}

// Converts a VarToDataTypeMap into a new Python dict that maps each variable
// name (str) to its DataType enum value as a Python int.
// Reference counting follows the VarToShapeMap typemap above:
// PyDict_SetItem does not steal references, so `key` and `value` remain owned
// by their safe pointers and are dropped at the end of each iteration.
%typemap(out) const tensorflow::checkpoint::TensorSliceReader::VarToDataTypeMap& {
  tensorflow::Safe_PyObjectPtr output_map(tensorflow::make_safe(PyDict_New()));
  if (!output_map) {
    SWIG_fail;
  }
  for (const auto& v : *$1) {
%#if PY_MAJOR_VERSION >= 3
    tensorflow::Safe_PyObjectPtr key(
        tensorflow::make_safe(PyUnicode_FromStringAndSize(v.first.c_str(),
                                                          v.first.size())));
%#else
    tensorflow::Safe_PyObjectPtr key(
        tensorflow::make_safe(PyString_FromStringAndSize(v.first.c_str(),
                                                         v.first.size())));
%#endif
    if (!key) {
      SWIG_fail;
    }
%#if PY_MAJOR_VERSION >= 3
    tensorflow::Safe_PyObjectPtr value(
        tensorflow::make_safe(PyLong_FromLong(v.second)));
%#else
    tensorflow::Safe_PyObjectPtr value(
        tensorflow::make_safe(PyInt_FromLong(v.second)));
%#endif
    if (!value) {
      SWIG_fail;
    }
    if (PyDict_SetItem(output_map.get(), key.get(), value.get()) == -1) {
      SWIG_fail;
    }
  }

  $result = output_map.release();
}

%{
// Reads the tensor called `name` from `reader` and converts it to a NumPy
// ndarray via ConvertTensorToNdarray.
//
// Returns a NEW reference in every case. The SWIG wrapper hands the returned
// PyObject* to Python as an owned reference, so the fallback value is an
// incref'd None rather than a borrowed Py_None (which would be over-released).
// Errors from reading or converting the tensor are reported through
// `out_status`.
static PyObject* CheckpointReader_GetTensor(
    tensorflow::checkpoint::CheckpointReader* reader,
    const string& name,
    TF_Status* out_status) {
  PyObject* py_obj = nullptr;
  std::unique_ptr<tensorflow::Tensor> tensor;
  reader->GetTensor(name, &tensor, out_status);
  if (TF_GetCode(out_status) == TF_OK) {
    tensorflow::Status status =
        tensorflow::ConvertTensorToNdarray(*tensor, &py_obj);
    if (!status.ok()) {
      Set_TF_Status_from_Status(out_status, status);
    }
  }
  if (py_obj == nullptr) {
    // Py_RETURN_NONE increments None's refcount before returning it.
    Py_RETURN_NONE;
  }
  return py_obj;
}
%}

// Wrap this function so the Python get_tensor() method can call it.
PyObject* CheckpointReader_GetTensor(
    tensorflow::checkpoint::CheckpointReader* reader,
    const string& name,
    TF_Status* out_status);

// Hide every declaration by default, then selectively expose the
// CheckpointReader API below. The directives must precede the %include of
// checkpoint_reader.h at the end of this file to take effect.
%ignoreall

%unignore tensorflow;
%unignore tensorflow::checkpoint;
%unignore tensorflow::checkpoint::CheckpointReader;
%unignore tensorflow::checkpoint::CheckpointReader::CheckpointReader;
%unignore tensorflow::checkpoint::CheckpointReader::~CheckpointReader;
// Python-facing names for the wrapped C++ methods. The underscore-prefixed
// names are private helpers called by the Python methods in the %extend
// block below.
%rename("debug_string") tensorflow::checkpoint::CheckpointReader::DebugString;
%rename("get_variable_to_shape_map") tensorflow::checkpoint::CheckpointReader::GetVariableToShapeMap;
%rename("_GetVariableToDataTypeMap") tensorflow::checkpoint::CheckpointReader::GetVariableToDataTypeMap;
%rename("_HasTensor") tensorflow::checkpoint::CheckpointReader::HasTensor;
%unignore CheckpointReader_GetTensor;

// Python-only convenience methods added to the generated CheckpointReader
// proxy class.
%extend tensorflow::checkpoint::CheckpointReader {
%insert("python") %{
  def get_variable_to_dtype_map(self):
    """Returns a dict mapping each variable name to a `dtypes.DType`."""
    from tensorflow.python.framework import dtypes
    return {name: dtypes.DType(type_enum)
            for name, type_enum in self._GetVariableToDataTypeMap().items()}

  def has_tensor(self, tensor_str):
    """Returns the C++ HasTensor() result for `tensor_str` (str or bytes)."""
    from tensorflow.python.util import compat
    return self._HasTensor(compat.as_bytes(tensor_str))

  def get_tensor(self, tensor_str):
    """Returns the tensor `tensor_str` as converted by CheckpointReader_GetTensor.

    A non-OK status from the C++ call is raised as an exception by
    `errors.raise_exception_on_not_ok_status()`.
    """
    from tensorflow.python.framework import errors
    with errors.raise_exception_on_not_ok_status() as status:
      from tensorflow.python.util import compat
      return CheckpointReader_GetTensor(self, compat.as_bytes(tensor_str),
                                        status)
%}
}

// Module-level factory; a non-OK construction status is raised as an
// exception by errors.raise_exception_on_not_ok_status().
%insert("python") %{
def NewCheckpointReader(filepattern):
  """Returns a CheckpointReader for `filepattern` (str or bytes)."""
  from tensorflow.python.framework import errors
  with errors.raise_exception_on_not_ok_status() as status:
    from tensorflow.python.util import compat
    return CheckpointReader(compat.as_bytes(filepattern), status)

NewCheckpointReader._tf_api_names = ['train.NewCheckpointReader']
%}

// Generate wrappers from the header; only the names unignored or renamed
// above are exposed. Default wrapping is restored afterwards.
%include "tensorflow/c/checkpoint_reader.h"
%unignoreall