1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_C_C_API_INTERNAL_H_ 17 #define TENSORFLOW_C_C_API_INTERNAL_H_ 18 19 #include "tensorflow/c/c_api.h" 20 21 #include <list> 22 #include <set> 23 #include <string> 24 #include <unordered_map> 25 #include <vector> 26 27 #ifndef __ANDROID__ 28 #include "tensorflow/core/framework/op_gen_lib.h" 29 #endif 30 #include "tensorflow/core/common_runtime/shape_refiner.h" 31 #include "tensorflow/core/framework/tensor.h" 32 #include "tensorflow/core/framework/tensor_shape.h" 33 #include "tensorflow/core/graph/graph.h" 34 #include "tensorflow/core/graph/graph_constructor.h" 35 #include "tensorflow/core/graph/node_builder.h" 36 #include "tensorflow/core/lib/core/status.h" 37 #include "tensorflow/core/platform/mutex.h" 38 #include "tensorflow/core/platform/types.h" 39 #include "tensorflow/core/public/session.h" 40 41 namespace tensorflow { 42 class Device; 43 class DeviceMgr; 44 } // namespace tensorflow 45 46 // Internal structures used by the C API. These are likely to change and should 47 // not be depended on. 48 49 struct TF_Status { 50 tensorflow::Status status; 51 }; 52 53 struct TF_Tensor { 54 ~TF_Tensor(); 55 56 TF_DataType dtype; 57 tensorflow::TensorShape shape; 58 tensorflow::TensorBuffer* buffer; 59 }; 60 61 struct TF_SessionOptions { 62 tensorflow::SessionOptions options; 63 }; 64 65 struct TF_DeprecatedSession { 66 tensorflow::Session* session; 67 }; 68 69 struct TF_Library { 70 void* lib_handle; 71 TF_Buffer op_list; 72 }; 73 74 struct TF_Graph { 75 TF_Graph(); 76 77 tensorflow::mutex mu; 78 tensorflow::Graph graph GUARDED_BY(mu); 79 80 // Runs shape inference. 81 tensorflow::ShapeRefiner refiner GUARDED_BY(mu); 82 83 // Maps from name of an operation to the Node* in 'graph'. 84 std::unordered_map<tensorflow::string, tensorflow::Node*> name_map 85 GUARDED_BY(mu); 86 87 // The keys of this map are all the active sessions using this graph. 88 // Each value is the current "runnability" status of the corresponding 89 // session. Under normal conditions all statuses are Status::OK(), but 90 // if some operation is mutated after it was run by a session (this 91 // is detected in RecordMutation function), that session is no longer 92 // safe to run. Its status will contain the error that will be returned 93 // to the user, should she try running this session. 94 // 95 // Sessions are added to this map in TF_NewSession, and removed in 96 // TF_DeleteSession. 97 // TF_Graph may only / must be deleted when 98 // sessions.size() == 0 && delete_requested == true 99 tensorflow::gtl::FlatMap<TF_Session*, tensorflow::Status> sessions 100 GUARDED_BY(mu); 101 bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph 102 103 // Used to link graphs contained in TF_WhileParams to the parent graph that 104 // will eventually contain the full while loop. 105 TF_Graph* parent; 106 TF_Output* parent_inputs; 107 }; 108 109 struct TF_OperationDescription { 110 TF_OperationDescription(TF_Graph* g, const char* op_type, 111 const char* node_name) 112 : node_builder(node_name, op_type, g->graph.op_registry()), graph(g) {} 113 114 tensorflow::NodeBuilder node_builder; 115 TF_Graph* graph; 116 std::set<tensorflow::string> colocation_constraints; 117 }; 118 119 struct TF_Operation { 120 tensorflow::Node node; 121 }; 122 123 struct TF_Session { 124 TF_Session(tensorflow::Session* s, TF_Graph* g); 125 126 tensorflow::Session* session; 127 TF_Graph* graph; 128 129 tensorflow::mutex mu; 130 int last_num_graph_nodes; 131 132 // NOTE(ashankar): Experimental fields to help keep the 133 // buffers of a TF_Tensor pinned in device memory. 134 const tensorflow::DeviceMgr* device_mgr; // Owned by session. 135 std::vector<tensorflow::Device*> devices; // Owned by device_mgr. 136 }; 137 138 struct TF_ImportGraphDefOptions { 139 tensorflow::ImportGraphDefOptions opts; 140 141 // Backing memory for TensorId fields in opts. 142 // TODO(skyewm): it'd be better if ImportGraphDefOptions owned this. 143 std::list<tensorflow::string> tensor_id_data; 144 }; 145 146 struct TF_ImportGraphDefResults { 147 std::vector<TF_Output> return_tensors; 148 std::vector<TF_Operation*> return_nodes; 149 std::vector<const char*> missing_unused_key_names; 150 std::vector<int> missing_unused_key_indexes; 151 152 // Backing memory for missing_unused_key_names values. 153 std::list<tensorflow::string> missing_unused_key_names_data; 154 }; 155 156 struct TF_DeviceList { 157 std::vector<tensorflow::DeviceAttributes> response; 158 }; 159 160 struct TF_Function { 161 tensorflow::FunctionDef fdef; 162 }; 163 164 struct TF_ApiDefMap { 165 explicit TF_ApiDefMap(const tensorflow::OpList& op_list) 166 : 167 #ifndef __ANDROID__ 168 api_def_map(op_list), 169 #endif 170 update_docs_called(false) { 171 } 172 173 #ifndef __ANDROID__ 174 tensorflow::ApiDefMap api_def_map GUARDED_BY(lock); 175 #endif 176 bool update_docs_called GUARDED_BY(lock); 177 tensorflow::mutex lock; 178 }; 179 180 namespace tensorflow { 181 182 class TensorCApi { 183 public: 184 static TensorBuffer* Buffer(const Tensor& tensor) { return tensor.buf_; } 185 static Tensor MakeTensor(TF_DataType type, const TensorShape& shape, 186 TensorBuffer* buf) { 187 return Tensor(static_cast<DataType>(type), shape, buf); 188 } 189 }; 190 191 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); 192 193 TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status); 194 195 Status MessageToBuffer(const tensorflow::protobuf::Message& in, TF_Buffer* out); 196 197 // Set the shapes and types of the output's handle. 198 // 199 // The lengths of the arrays pointed to by `shapes`, `ranks`, and `types` must 200 // all be equal to `num_shapes_and_types`. If `ranks[i] != -1`, (i.e., if the 201 // rank is known), then it must be equal to the length of `shapes[i]`; if 202 // `ranks[i] == 1`, then `shapes[i]` may be nullptr. 203 // 204 // TODO(akshayka): Implement a corresponding getter method. 205 void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, 206 int num_shapes_and_types, 207 const int64_t** shapes, 208 const int* ranks, 209 const TF_DataType* types, 210 TF_Status* status); 211 212 void RecordMutation(TF_Graph* graph, const TF_Operation& op, 213 const char* mutation_type); 214 215 } // end namespace tensorflow 216 217 #endif // TENSORFLOW_C_C_API_INTERNAL_H_ 218