Home | History | Annotate | Download | only in util
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 %include "tensorflow/python/lib/core/strings.i"
     17 %include "tensorflow/python/platform/base.i"
     18 
     19 %{
     20 #include "tensorflow/c/checkpoint_reader.h"
     21 #include "tensorflow/core/lib/core/status.h"
     22 #include "tensorflow/python/lib/core/ndarray_tensor.h"
     23 #include "tensorflow/python/lib/core/py_func.h"
     24 #include "tensorflow/python/lib/core/safe_ptr.h"
     25 %}
     26 
     27 %typemap(out) const tensorflow::checkpoint::TensorSliceReader::VarToShapeMap& {
     28   tensorflow::Safe_PyObjectPtr output_map(tensorflow::make_safe(PyDict_New()));
     29   for (auto v : *$1) {
     30 %#if PY_MAJOR_VERSION >= 3
     31     tensorflow::Safe_PyObjectPtr key(
     32         tensorflow::make_safe(PyUnicode_FromStringAndSize(v.first.c_str(),
     33             v.first.size())));
     34 %#else
     35     tensorflow::Safe_PyObjectPtr key(
     36         tensorflow::make_safe(PyString_FromStringAndSize(v.first.c_str(),
     37             v.first.size())));
     38 %#endif
     39     if (!key) {
     40       SWIG_fail;
     41     }
     42     size_t dims = v.second.dims();
     43     tensorflow::Safe_PyObjectPtr value(tensorflow::make_safe(PyList_New(dims)));
     44     if (!value) {
     45       SWIG_fail;
     46     }
     47     for (size_t i = 0; i < dims; ++i) {
     48 %#if PY_MAJOR_VERSION >= 3
     49       tensorflow::Safe_PyObjectPtr dim_value(
     50           tensorflow::make_safe(PyLong_FromLong(v.second.dim_size(i))));
     51 %#else
     52       tensorflow::Safe_PyObjectPtr dim_value(
     53           tensorflow::make_safe(PyInt_FromLong(v.second.dim_size(i))));
     54 %#endif
     55       if (!dim_value) {
     56         SWIG_fail;
     57       }
     58       PyList_SET_ITEM(value.get(), i, dim_value.release());
     59     }
     60     if (PyDict_SetItem(output_map.get(), key.get(), value.get()) == -1) {
     61       SWIG_fail;
     62     } else {
     63       key.release();
     64       value.release();
     65     }
     66   }
     67 
     68   $result = output_map.release();
     69 }
     70 
     71 %typemap(out) const tensorflow::checkpoint::TensorSliceReader::VarToDataTypeMap& {
     72   tensorflow::Safe_PyObjectPtr output_map(tensorflow::make_safe(PyDict_New()));
     73   for (auto v : *$1) {
     74 %#if PY_MAJOR_VERSION >= 3
     75     tensorflow::Safe_PyObjectPtr key(
     76         tensorflow::make_safe(PyUnicode_FromStringAndSize(v.first.c_str(), v.first.size())));
     77 %#else
     78     tensorflow::Safe_PyObjectPtr key(
     79         tensorflow::make_safe(PyString_FromStringAndSize(v.first.c_str(), v.first.size())));
     80 %#endif
     81     if (!key) {
     82       SWIG_fail;
     83     }
     84 %#if PY_MAJOR_VERSION >= 3
     85     tensorflow::Safe_PyObjectPtr value(tensorflow::make_safe(PyLong_FromLong(v.second)));
     86 %#else
     87     tensorflow::Safe_PyObjectPtr value(tensorflow::make_safe(PyInt_FromLong(v.second)));
     88 %#endif
     89     if (!value) {
     90       SWIG_fail;
     91     }
     92     if (PyDict_SetItem(output_map.get(), key.get(), value.get()) == -1) {
     93       SWIG_fail;
     94     } else {
     95       key.release();
     96       value.release();
     97     }
     98   }
     99 
    100   $result = output_map.release();
    101 }
    102 
    103 %{
    104 static PyObject* CheckpointReader_GetTensor(
    105       tensorflow::checkpoint::CheckpointReader* reader,
    106       const string& name,
    107       TF_Status* out_status) {
    108   PyObject* py_obj = Py_None;
    109   std::unique_ptr<tensorflow::Tensor> tensor;
    110   reader->GetTensor(name, &tensor, out_status);
    111   if (TF_GetCode(out_status) == TF_OK) {
    112     tensorflow::Status status =
    113         tensorflow::ConvertTensorToNdarray(*tensor.get(), &py_obj);
    114     if (!status.ok()) {
    115       Set_TF_Status_from_Status(out_status, status);
    116     }
    117   }
    118   return py_obj;
    119 }
    120 %}
    121 
// Wrap this function: declaring it here, outside the verbatim C++ section
// above, makes SWIG generate a Python-callable wrapper for
// CheckpointReader_GetTensor. The Python get_tensor() method below calls that
// wrapper.
PyObject* CheckpointReader_GetTensor(
    tensorflow::checkpoint::CheckpointReader* reader,
    const string& name,
    TF_Status* out_status);
    127 
// Hide everything by default, then selectively expose only the pieces of the
// checkpoint API that the Python layer uses.
%ignoreall

// Expose the namespaces, the CheckpointReader class and its ctor/dtor.
%unignore tensorflow;
%unignore tensorflow::checkpoint;
%unignore tensorflow::checkpoint::CheckpointReader;
%unignore tensorflow::checkpoint::CheckpointReader::CheckpointReader;
%unignore tensorflow::checkpoint::CheckpointReader::~CheckpointReader;
// Public snake_case names for direct use from Python.
// NOTE(review): these renamed methods are not explicitly unignored. This
// presumably relies on a specific rename overriding the blanket ignore --
// confirm against the ignoreall definition in base.i.
%rename("debug_string") tensorflow::checkpoint::CheckpointReader::DebugString;
%rename("get_variable_to_shape_map") tensorflow::checkpoint::CheckpointReader::GetVariableToShapeMap;
// Underscore-prefixed (private) names. They are wrapped by the Python helpers
// get_variable_to_dtype_map() and has_tensor() in the extend section below.
%rename("_GetVariableToDataTypeMap") tensorflow::checkpoint::CheckpointReader::GetVariableToDataTypeMap;
%rename("_HasTensor") tensorflow::checkpoint::CheckpointReader::HasTensor;
// The free helper declared above, used by get_tensor().
%unignore CheckpointReader_GetTensor;
    140 
    141 %extend tensorflow::checkpoint::CheckpointReader {
    142 %insert("python") %{
    143   def get_variable_to_dtype_map(self):
    144     from tensorflow.python.framework import dtypes
    145     return {name: dtypes.DType(type_enum)
    146             for name, type_enum in self._GetVariableToDataTypeMap().items()}
    147 
    148   def has_tensor(self, tensor_str):
    149     from tensorflow.python.util import compat
    150     return self._HasTensor(compat.as_bytes(tensor_str))
    151 
    152   def get_tensor(self, tensor_str):
    153     from tensorflow.python.framework import errors
    154     with errors.raise_exception_on_not_ok_status() as status:
    155       from tensorflow.python.util import compat
    156       return CheckpointReader_GetTensor(self, compat.as_bytes(tensor_str),
    157                                         status)
    158 %}
    159 }
    160 
    161 %insert("python") %{
    162 def NewCheckpointReader(filepattern):
    163   from tensorflow.python.framework import errors
    164   with errors.raise_exception_on_not_ok_status() as status:
    165     from tensorflow.python.util import compat
    166     return CheckpointReader(compat.as_bytes(filepattern), status)
    167 
    168 NewCheckpointReader._tf_api_names = ['train.NewCheckpointReader']
    169 %}
    170 
// Parse the real header so SWIG generates wrappers for the declarations that
// the directives above left exposed, then restore the default (non-ignoring)
// state for anything that follows.
%include "tensorflow/c/checkpoint_reader.h"
%unignoreall
    173