Home | History | Annotate | Download | only in c
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_C_C_TEST_UTIL_H_
     17 #define TENSORFLOW_C_C_TEST_UTIL_H_
     18 
     19 #include "tensorflow/c/c_api.h"
     20 
     21 #include <vector>
     22 #include "tensorflow/core/framework/attr_value.pb.h"
     23 #include "tensorflow/core/framework/graph.pb.h"
     24 #include "tensorflow/core/framework/node_def.pb.h"
     25 #include "tensorflow/core/framework/types.pb.h"
     26 #include "tensorflow/core/platform/test.h"
     27 
     28 using ::tensorflow::string;
     29 
     30 typedef std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)>
     31     unique_tensor_ptr;
     32 
     33 // Create a tensor with values of type TF_INT8 provided by `values`.
     34 TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values);
     35 
     36 // Create a tensor with values of type TF_INT32 provided by `values`.
     37 TF_Tensor* Int32Tensor(const int64_t* dims, int num_dims,
     38                        const int32_t* values);
     39 
     40 // Create 1 dimensional tensor with values from `values`
     41 TF_Tensor* Int32Tensor(const std::vector<int32_t>& values);
     42 
     43 TF_Tensor* Int32Tensor(int32_t v);
     44 
     45 TF_Tensor* DoubleTensor(double v);
     46 
     47 TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s,
     48                           const char* name = "feed");
     49 
     50 TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
     51                     const char* name = "const");
     52 
     53 TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
     54                           const char* name = "scalar");
     55 
     56 TF_Operation* ScalarConst(double v, TF_Graph* graph, TF_Status* s,
     57                           const char* name = "scalar");
     58 
     59 TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
     60                   TF_Status* s, const char* name = "add");
     61 
     62 TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
     63                          TF_Status* s, const char* name = "add");
     64 
     65 TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r,
     66                                     TF_Graph* graph, TF_Operation* ctrl_op,
     67                                     TF_Status* s, const char* name = "add");
     68 
     69 TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
     70                   const char* name = "add");
     71 
     72 TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
     73                   TF_Status* s, const char* name = "min");
     74 
     75 // If `op_device` is non-empty, set the created op on that device.
     76 TF_Operation* MinWithDevice(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
     77                             const string& op_device, TF_Status* s,
     78                             const char* name = "min");
     79 
     80 TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s,
     81                   const char* name = "neg");
     82 
     83 TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s);
     84 
     85 TF_Operation* RandomUniform(TF_Operation* shape, TF_DataType dtype,
     86                             TF_Graph* graph, TF_Status* s);
     87 
     88 // Split `input` along the first dimension into 3 tensors
     89 TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s,
     90                      const char* name = "split3");
     91 
     92 bool IsPlaceholder(const tensorflow::NodeDef& node_def);
     93 
     94 bool IsScalarConst(const tensorflow::NodeDef& node_def, int v);
     95 
     96 bool IsAddN(const tensorflow::NodeDef& node_def, int n);
     97 
     98 bool IsNeg(const tensorflow::NodeDef& node_def, const string& input);
     99 
    100 bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def);
    101 
    102 bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def);
    103 
    104 bool GetFunctionDef(TF_Function* func, tensorflow::FunctionDef* func_def);
    105 
    106 bool GetAttrValue(TF_Operation* oper, const char* attr_name,
    107                   tensorflow::AttrValue* attr_value, TF_Status* s);
    108 
    109 // Returns a sorted vector of std::pair<function_name, gradient_func> from
    110 // graph_def.library().gradient()
    111 std::vector<std::pair<string, string>> GetGradDefs(
    112     const tensorflow::GraphDef& graph_def);
    113 
    114 // Returns a sorted vector of names contained in `grad_def`
    115 std::vector<string> GetFuncNames(const tensorflow::GraphDef& graph_def);
    116 
    117 class CSession {
    118  public:
    119   CSession(TF_Graph* graph, TF_Status* s, bool use_XLA = false);
    120   explicit CSession(TF_Session* session);
    121 
    122   ~CSession();
    123 
    124   void SetInputs(std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs);
    125   void SetOutputs(std::initializer_list<TF_Operation*> outputs);
    126   void SetOutputs(const std::vector<TF_Output>& outputs);
    127   void SetTargets(std::initializer_list<TF_Operation*> targets);
    128 
    129   void Run(TF_Status* s);
    130 
    131   void CloseAndDelete(TF_Status* s);
    132 
    133   TF_Tensor* output_tensor(int i) { return output_values_[i]; }
    134 
    135   TF_Session* mutable_session() { return session_; }
    136 
    137  private:
    138   void DeleteInputValues();
    139   void ResetOutputValues();
    140 
    141   TF_Session* session_;
    142   std::vector<TF_Output> inputs_;
    143   std::vector<TF_Tensor*> input_values_;
    144   std::vector<TF_Output> outputs_;
    145   std::vector<TF_Tensor*> output_values_;
    146   std::vector<TF_Operation*> targets_;
    147 };
    148 
    149 #endif  // TENSORFLOW_C_C_TEST_UTIL_H_
    150