Home | History | Annotate | Download | only in client
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_
     17 #define TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_
     18 
// Must be included first
#include "tensorflow/python/lib/core/numpy.h"

#include <cstdint>
#include <vector>

#include "tensorflow/c/c_api.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
     27 
     28 namespace tensorflow {
     29 
     30 // Container types for the various arguments and temporary values used
     31 // in the wrapper.
     32 
     33 // A NameVector is a vector of tensor or operation names, as borrowed
     34 // C strings.
     35 typedef tensorflow::gtl::InlinedVector<const char*, 8> NameVector;
     36 
     37 // A PyObjectVector is a vector of borrowed pointers to PyObjects.
     38 typedef tensorflow::gtl::InlinedVector<PyObject*, 8> PyObjectVector;
     39 
     40 // A TF_TensorVector is a vector of borrowed pointers to TF_Tensors.
     41 typedef gtl::InlinedVector<TF_Tensor*, 8> TF_TensorVector;
     42 
     43 // Run the graph associated with the session starting with the
     44 // supplied inputs[].  Regardless of success or failure, inputs[] are
     45 // stolen by the implementation (i.e. the implementation will
     46 // eventually call Py_DECREF on each array input).
     47 //
     48 // The PyObject* feed_dict must be a dictionary mapping strings to
     49 // NumPy arrays. This function does not modify its reference count.
     50 //
     51 // On success, the tensors corresponding to output_names[0,noutputs-1]
     52 // are placed in out_values[], and these outputs[] become the property
     53 // of the caller (the caller must eventually call Py_DECREF on them).
     54 //
     55 // On failure, out_status contains a tensorflow::Status with an error
     56 // message.
     57 void TF_Run_wrapper(TF_DeprecatedSession* session, const TF_Buffer* run_options,
     58                     PyObject* feed_dict, const NameVector& output_names,
     59                     const NameVector& target_nodes, TF_Status* out_status,
     60                     PyObjectVector* out_values, TF_Buffer* run_outputs);
     61 
     62 // Set up the graph with the intended feeds and fetches for partial run.
     63 // *out_handle is owned by the caller.
     64 //
     65 // On success, returns a handle that is used for subsequent PRun calls.
     66 //
     67 // On failure, out_status contains a tensorflow::Status with an error
     68 // message.
     69 void TF_PRunSetup_wrapper(TF_DeprecatedSession* session,
     70                           const NameVector& input_names,
     71                           const NameVector& output_names,
     72                           const NameVector& target_nodes, TF_Status* out_status,
     73                           const char** out_handle);
     74 
     75 // Continue to run the graph with additional feeds and fetches. The
     76 // execution state is uniquely identified by the handle.
     77 //
     78 // The PyObject* feed_dict must be a dictionary mapping strings to
     79 // NumPy arrays. This function does not modify its reference count.
     80 //
     81 // On success,  the tensors corresponding to output_names[0,noutputs-1]
     82 // are placed in out_values[], and these outputs[] become the property
     83 // of the caller (the caller must eventually call Py_DECREF on them).
     84 //
     85 // On failure,  out_status contains a tensorflow::Status with an error
     86 // message.
     87 void TF_PRun_wrapper(TF_DeprecatedSession* session, const char* handle,
     88                      PyObject* feed_dict, const NameVector& output_names,
     89                      TF_Status* out_status, PyObjectVector* out_values);
     90 
     91 // Wrapper for TF_Reset that converts the string vectors to character arrays.
     92 void TF_Reset_wrapper(const TF_SessionOptions* opt,
     93                       const NameVector& containers, TF_Status* out_status);
     94 
     95 // Convenience wrapper around EqualGraphDef to make it easier to wrap.
     96 // Returns an explanation if a difference is found, or the empty string
     97 // for no difference.
     98 string EqualGraphDefWrapper(const string& actual, const string& expected);
     99 
    100 // Convenience wrapper around AreAttrValuesEqual to make it easier to wrap.
    101 // The actual and expected strings must correspond to a serialized binary
    102 // representation of two AttrValue proto instances.
    103 // Returns an explanation if a difference is found, or the empty string
    104 // for no difference.
    105 string EqualAttrValueWrapper(const string& actual, const string& expected);
    106 
    107 // Gets shape from C API Graph object.
    108 //
    109 // If shape is known, returns shape vector where -1 means "unknown
    110 // dimension".  Sets unknown_shape to false.
    111 //
    112 // If shape is unknown, sets unknown_shape to true.
    113 tensorflow::gtl::InlinedVector<int64_t, 6> TF_GraphGetTensorShapeHelper(
    114     TF_Graph* graph, TF_Output output, TF_Status* out_status,
    115     bool* unknown_shape);
    116 
    117 // Runs the graph associated with the session starting with the supplied inputs.
    118 // On success, `py_outputs` is populated with a numpy ndarray for each output
    119 // (the caller must decref these ndarrays, although this will likely be handled
    120 // by the Python gc). `session`, `out_status`, and `py_outputs` must be
    121 // non-null. `py_outputs` should be empty.
    122 void TF_SessionRun_wrapper(TF_Session* session, const TF_Buffer* run_options,
    123                            const std::vector<TF_Output>& inputs,
    124                            const std::vector<PyObject*>& input_ndarrays,
    125                            const std::vector<TF_Output>& outputs,
    126                            const std::vector<TF_Operation*>& targets,
    127                            TF_Buffer* run_metadata, TF_Status* out_status,
    128                            std::vector<PyObject*>* py_outputs);
    129 
    130 // Set up the graph with the intended feeds (inputs) and fetches (output) for
    131 // a sequence of partial run calls.
    132 //
    133 // On success, returns a handle that can be used for subsequent PRun calls. The
    134 // handle is owned by the caller and should be deleted with TF_DeletePRunHandle
    135 // when it is no longer needed.
    136 //
    137 // On failure, out_status contains a tensorflow::Status with an error
    138 // message.
    139 void TF_SessionPRunSetup_wrapper(TF_Session* session,
    140                                  const std::vector<TF_Output>& inputs,
    141                                  const std::vector<TF_Output>& outputs,
    142                                  const std::vector<TF_Operation*>& targets,
    143                                  const char** out_handle,
    144                                  TF_Status* out_status);
    145 
    146 // Continue to run the graph with additional feeds and fetches. The
    147 // execution state is uniquely identified by the handle.
    148 //
    149 // On success, `py_outputs` is populated with a numpy ndarray for each output
    150 // (the caller must decref these ndarrays, although this will likely be handled
    151 // by the Python gc). `session`, `handle`, `out_status`, and `py_outputs` must
    152 // be non-null. `py_outputs` should be empty.
    153 //
    154 // On failure, out_status contains a tensorflow::Status with an error
    155 // message.
    156 void TF_SessionPRun_wrapper(TF_Session* session, const char* handle,
    157                             const std::vector<TF_Output>& inputs,
    158                             const std::vector<PyObject*>& input_ndarrays,
    159                             const std::vector<TF_Output>& outputs,
    160                             TF_Status* out_status,
    161                             std::vector<PyObject*>* py_outputs);
    162 
    163 // Retrieves the inputs of this operation.
    164 std::vector<TF_Output> GetOperationInputs(TF_Operation* oper);
    165 
    166 // Retrieves the control inputs of this operation.
    167 std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper(
    168     TF_Operation* oper);
    169 
    170 // Retrieves the op names of the consumers of `oper_out`. The returned strings
    171 // have the lifetime of the underlying TF_Graph.
    172 std::vector<const char*> TF_OperationOutputConsumers_wrapper(
    173     TF_Output oper_out);
    174 
    175 // `opers` equaling NULL are converted to `nopers = -1`.
    176 // `output_names` must be empty or have the same length as `outputs`.
    177 TF_Function* TF_GraphToFunction_wrapper(
    178     const TF_Graph* fn_body, const char* fn_name, bool append_hash_to_fn_name,
    179     const std::vector<TF_Operation*>* opers,
    180     const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& outputs,
    181     const NameVector& output_names, const TF_FunctionOptions* opts,
    182     const char* description, TF_Status* out_status);
    183 
    184 // Set the shapes and types for the output's handle.
    185 //
    186 // The sizes of 'shapes', 'ranks', and 'types' must be equal; `shapes[i]`
    187 // contains the shape of the handle's i-th value, `ranks[i]` contains the i-th
    188 // shape's rank, and `types[i]` contains the i-th value's dtype. If the i-th
    189 // shape is unknown, then `ranks[i]` must be equal to -1.
    190 //
    191 // The space between the double angle brackets below looks extraneous, but
    192 // our version of SWIG cannot parse ">>".
    193 void TF_GraphSetOutputHandleShapesAndTypes_wrapper(
    194     TF_Graph* graph, TF_Output output,
    195     const std::vector<std::vector<int64_t> >& shapes,
    196     const std::vector<int>& ranks, const std::vector<TF_DataType>& types,
    197     TF_Status* status);
    198 
    199 // Set the shape of output. If unknown is true, `num_dims` must be set to
    200 // -1 and `dims` is set to nullptr.
    201 void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output,
    202                                     const std::vector<int64_t>& dims,
    203                                     bool unknown_shape, TF_Status* status);
    204 
    205 // Return the shape of output. `num_dims` should be the output of
    206 // TF_GraphGetTensorNumDims. If `num_dims = -1`, this should not be called.
    207 std::vector<int64_t> TF_GraphGetTensorShape_wrapper(TF_Graph* graph,
    208                                                     TF_Output output,
    209                                                     int num_dims,
    210                                                     TF_Status* status);
    211 
    212 // Returns the string representations of the missing unused input mappings.
    213 std::vector<string> TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
    214     TF_ImportGraphDefResults* results);
    215 
    216 }  // namespace tensorflow
    217 
    218 #endif  // TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_
    219