1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_ 17 #define TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_ 18 19 // Must be included first 20 #include "tensorflow/python/lib/core/numpy.h" 21 22 #include "tensorflow/c/c_api.h" 23 #include "tensorflow/core/framework/graph.pb.h" 24 #include "tensorflow/core/lib/core/errors.h" 25 #include "tensorflow/core/lib/core/status.h" 26 #include "tensorflow/core/lib/gtl/inlined_vector.h" 27 28 namespace tensorflow { 29 30 // Container types for the various arguments and temporary values used 31 // in the wrapper. 32 33 // A NameVector is a vector of tensor or operation names, as borrowed 34 // C strings. 35 typedef tensorflow::gtl::InlinedVector<const char*, 8> NameVector; 36 37 // A PyObjectVector is a vector of borrowed pointers to PyObjects. 38 typedef tensorflow::gtl::InlinedVector<PyObject*, 8> PyObjectVector; 39 40 // A TF_TensorVector is a vector of borrowed pointers to TF_Tensors. 41 typedef gtl::InlinedVector<TF_Tensor*, 8> TF_TensorVector; 42 43 // Run the graph associated with the session starting with the 44 // supplied inputs[]. Regardless of success or failure, inputs[] are 45 // stolen by the implementation (i.e. the implementation will 46 // eventually call Py_DECREF on each array input). 47 // 48 // The PyObject* feed_dict must be a dictionary mapping strings to 49 // NumPy arrays. This function does not modify its reference count. 50 // 51 // On success, the tensors corresponding to output_names[0,noutputs-1] 52 // are placed in out_values[], and these outputs[] become the property 53 // of the caller (the caller must eventually call Py_DECREF on them). 54 // 55 // On failure, out_status contains a tensorflow::Status with an error 56 // message. 57 void TF_Run_wrapper(TF_DeprecatedSession* session, const TF_Buffer* run_options, 58 PyObject* feed_dict, const NameVector& output_names, 59 const NameVector& target_nodes, TF_Status* out_status, 60 PyObjectVector* out_values, TF_Buffer* run_outputs); 61 62 // Set up the graph with the intended feeds and fetches for partial run. 63 // *out_handle is owned by the caller. 64 // 65 // On success, returns a handle that is used for subsequent PRun calls. 66 // 67 // On failure, out_status contains a tensorflow::Status with an error 68 // message. 69 void TF_PRunSetup_wrapper(TF_DeprecatedSession* session, 70 const NameVector& input_names, 71 const NameVector& output_names, 72 const NameVector& target_nodes, TF_Status* out_status, 73 const char** out_handle); 74 75 // Continue to run the graph with additional feeds and fetches. The 76 // execution state is uniquely identified by the handle. 77 // 78 // The PyObject* feed_dict must be a dictionary mapping strings to 79 // NumPy arrays. This function does not modify its reference count. 80 // 81 // On success, the tensors corresponding to output_names[0,noutputs-1] 82 // are placed in out_values[], and these outputs[] become the property 83 // of the caller (the caller must eventually call Py_DECREF on them). 84 // 85 // On failure, out_status contains a tensorflow::Status with an error 86 // message. 87 void TF_PRun_wrapper(TF_DeprecatedSession* session, const char* handle, 88 PyObject* feed_dict, const NameVector& output_names, 89 TF_Status* out_status, PyObjectVector* out_values); 90 91 // Wrapper for TF_Reset that converts the string vectors to character arrays. 92 void TF_Reset_wrapper(const TF_SessionOptions* opt, 93 const NameVector& containers, TF_Status* out_status); 94 95 // Convenience wrapper around EqualGraphDef to make it easier to wrap. 96 // Returns an explanation if a difference is found, or the empty string 97 // for no difference. 98 string EqualGraphDefWrapper(const string& actual, const string& expected); 99 100 // Convenience wrapper around AreAttrValuesEqual to make it easier to wrap. 101 // The actual and expected strings must correspond to a serialized binary 102 // representation of two AttrValue proto instances. 103 // Returns an explanation if a difference is found, or the empty string 104 // for no difference. 105 string EqualAttrValueWrapper(const string& actual, const string& expected); 106 107 // Gets shape from C API Graph object. 108 // 109 // If shape is known, returns shape vector where -1 means "unknown 110 // dimension". Sets unknown_shape to false. 111 // 112 // If shape is unknown, sets unknown_shape to true. 113 tensorflow::gtl::InlinedVector<int64_t, 6> TF_GraphGetTensorShapeHelper( 114 TF_Graph* graph, TF_Output output, TF_Status* out_status, 115 bool* unknown_shape); 116 117 // Runs the graph associated with the session starting with the supplied inputs. 118 // On success, `py_outputs` is populated with a numpy ndarray for each output 119 // (the caller must decref these ndarrays, although this will likely be handled 120 // by the Python gc). `session`, `out_status`, and `py_outputs` must be 121 // non-null. `py_outputs` should be empty. 122 void TF_SessionRun_wrapper(TF_Session* session, const TF_Buffer* run_options, 123 const std::vector<TF_Output>& inputs, 124 const std::vector<PyObject*>& input_ndarrays, 125 const std::vector<TF_Output>& outputs, 126 const std::vector<TF_Operation*>& targets, 127 TF_Buffer* run_metadata, TF_Status* out_status, 128 std::vector<PyObject*>* py_outputs); 129 130 // Set up the graph with the intended feeds (inputs) and fetches (output) for 131 // a sequence of partial run calls. 132 // 133 // On success, returns a handle that can be used for subsequent PRun calls. The 134 // handle is owned by the caller and should be deleted with TF_DeletePRunHandle 135 // when it is no longer needed. 136 // 137 // On failure, out_status contains a tensorflow::Status with an error 138 // message. 139 void TF_SessionPRunSetup_wrapper(TF_Session* session, 140 const std::vector<TF_Output>& inputs, 141 const std::vector<TF_Output>& outputs, 142 const std::vector<TF_Operation*>& targets, 143 const char** out_handle, 144 TF_Status* out_status); 145 146 // Continue to run the graph with additional feeds and fetches. The 147 // execution state is uniquely identified by the handle. 148 // 149 // On success, `py_outputs` is populated with a numpy ndarray for each output 150 // (the caller must decref these ndarrays, although this will likely be handled 151 // by the Python gc). `session`, `handle`, `out_status`, and `py_outputs` must 152 // be non-null. `py_outputs` should be empty. 153 // 154 // On failure, out_status contains a tensorflow::Status with an error 155 // message. 156 void TF_SessionPRun_wrapper(TF_Session* session, const char* handle, 157 const std::vector<TF_Output>& inputs, 158 const std::vector<PyObject*>& input_ndarrays, 159 const std::vector<TF_Output>& outputs, 160 TF_Status* out_status, 161 std::vector<PyObject*>* py_outputs); 162 163 // Retrieves the inputs of this operation. 164 std::vector<TF_Output> GetOperationInputs(TF_Operation* oper); 165 166 // Retrieves the control inputs of this operation. 167 std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper( 168 TF_Operation* oper); 169 170 // Retrieves the op names of the consumers of `oper_out`. The returned strings 171 // have the lifetime of the underlying TF_Graph. 172 std::vector<const char*> TF_OperationOutputConsumers_wrapper( 173 TF_Output oper_out); 174 175 // `opers` equaling NULL are converted to `nopers = -1`. 176 // `output_names` must be empty or have the same length as `outputs`. 177 TF_Function* TF_GraphToFunction_wrapper( 178 const TF_Graph* fn_body, const char* fn_name, bool append_hash_to_fn_name, 179 const std::vector<TF_Operation*>* opers, 180 const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& outputs, 181 const NameVector& output_names, const TF_FunctionOptions* opts, 182 const char* description, TF_Status* out_status); 183 184 // Set the shapes and types for the output's handle. 185 // 186 // The sizes of 'shapes', 'ranks', and 'types' must be equal; `shapes[i]` 187 // contains the shape of the handle's i-th value, `ranks[i]` contains the i-th 188 // shape's rank, and `types[i]` contains the i-th value's dtype. If the i-th 189 // shape is unknown, then `ranks[i]` must be equal to -1. 190 // 191 // The space between the double angle brackets below looks extraneous, but 192 // our version of SWIG cannot parse ">>". 193 void TF_GraphSetOutputHandleShapesAndTypes_wrapper( 194 TF_Graph* graph, TF_Output output, 195 const std::vector<std::vector<int64_t> >& shapes, 196 const std::vector<int>& ranks, const std::vector<TF_DataType>& types, 197 TF_Status* status); 198 199 // Set the shape of output. If unknown is true, `num_dims` must be set to 200 // -1 and `dims` is set to nullptr. 201 void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output, 202 const std::vector<int64_t>& dims, 203 bool unknown_shape, TF_Status* status); 204 205 // Return the shape of output. `num_dims` should be the output of 206 // TF_GraphGetTensorNumDims. If `num_dims = -1`, this should not be called. 207 std::vector<int64_t> TF_GraphGetTensorShape_wrapper(TF_Graph* graph, 208 TF_Output output, 209 int num_dims, 210 TF_Status* status); 211 212 // Returns the string representations of the missing unused input mappings. 213 std::vector<string> TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper( 214 TF_ImportGraphDefResults* results); 215 216 } // namespace tensorflow 217 218 #endif // TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_ 219