1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_C_C_TEST_UTIL_H_ 17 #define TENSORFLOW_C_C_TEST_UTIL_H_ 18 19 #include "tensorflow/c/c_api.h" 20 21 #include <vector> 22 #include "tensorflow/core/framework/attr_value.pb.h" 23 #include "tensorflow/core/framework/graph.pb.h" 24 #include "tensorflow/core/framework/node_def.pb.h" 25 #include "tensorflow/core/framework/types.pb.h" 26 #include "tensorflow/core/platform/test.h" 27 28 using ::tensorflow::string; 29 30 typedef std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> 31 unique_tensor_ptr; 32 33 // Create a tensor with values of type TF_INT8 provided by `values`. 34 TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values); 35 36 // Create a tensor with values of type TF_INT32 provided by `values`. 37 TF_Tensor* Int32Tensor(const int64_t* dims, int num_dims, 38 const int32_t* values); 39 40 // Create 1 dimensional tensor with values from `values` 41 TF_Tensor* Int32Tensor(const std::vector<int32_t>& values); 42 43 TF_Tensor* Int32Tensor(int32_t v); 44 45 TF_Tensor* DoubleTensor(double v); 46 47 TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, 48 const char* name = "feed"); 49 50 TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, 51 const char* name = "const"); 52 53 TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, 54 const char* name = "scalar"); 55 56 TF_Operation* ScalarConst(double v, TF_Graph* graph, TF_Status* s, 57 const char* name = "scalar"); 58 59 TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, 60 TF_Status* s, const char* name = "add"); 61 62 TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph, 63 TF_Status* s, const char* name = "add"); 64 65 TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r, 66 TF_Graph* graph, TF_Operation* ctrl_op, 67 TF_Status* s, const char* name = "add"); 68 69 TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, 70 const char* name = "add"); 71 72 TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph, 73 TF_Status* s, const char* name = "min"); 74 75 // If `op_device` is non-empty, set the created op on that device. 76 TF_Operation* MinWithDevice(TF_Operation* l, TF_Operation* r, TF_Graph* graph, 77 const string& op_device, TF_Status* s, 78 const char* name = "min"); 79 80 TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s, 81 const char* name = "neg"); 82 83 TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s); 84 85 TF_Operation* RandomUniform(TF_Operation* shape, TF_DataType dtype, 86 TF_Graph* graph, TF_Status* s); 87 88 // Split `input` along the first dimension into 3 tensors 89 TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s, 90 const char* name = "split3"); 91 92 bool IsPlaceholder(const tensorflow::NodeDef& node_def); 93 94 bool IsScalarConst(const tensorflow::NodeDef& node_def, int v); 95 96 bool IsAddN(const tensorflow::NodeDef& node_def, int n); 97 98 bool IsNeg(const tensorflow::NodeDef& node_def, const string& input); 99 100 bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def); 101 102 bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def); 103 104 bool GetFunctionDef(TF_Function* func, tensorflow::FunctionDef* func_def); 105 106 bool GetAttrValue(TF_Operation* oper, const char* attr_name, 107 tensorflow::AttrValue* attr_value, TF_Status* s); 108 109 // Returns a sorted vector of std::pair<function_name, gradient_func> from 110 // graph_def.library().gradient() 111 std::vector<std::pair<string, string>> GetGradDefs( 112 const tensorflow::GraphDef& graph_def); 113 114 // Returns a sorted vector of names contained in `grad_def` 115 std::vector<string> GetFuncNames(const tensorflow::GraphDef& graph_def); 116 117 class CSession { 118 public: 119 CSession(TF_Graph* graph, TF_Status* s, bool use_XLA = false); 120 explicit CSession(TF_Session* session); 121 122 ~CSession(); 123 124 void SetInputs(std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs); 125 void SetOutputs(std::initializer_list<TF_Operation*> outputs); 126 void SetOutputs(const std::vector<TF_Output>& outputs); 127 void SetTargets(std::initializer_list<TF_Operation*> targets); 128 129 void Run(TF_Status* s); 130 131 void CloseAndDelete(TF_Status* s); 132 133 TF_Tensor* output_tensor(int i) { return output_values_[i]; } 134 135 TF_Session* mutable_session() { return session_; } 136 137 private: 138 void DeleteInputValues(); 139 void ResetOutputValues(); 140 141 TF_Session* session_; 142 std::vector<TF_Output> inputs_; 143 std::vector<TF_Tensor*> input_values_; 144 std::vector<TF_Output> outputs_; 145 std::vector<TF_Tensor*> output_values_; 146 std::vector<TF_Operation*> targets_; 147 }; 148 149 #endif // TENSORFLOW_C_C_TEST_UTIL_H_ 150