/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_UTILS_H_
#define TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_UTILS_H_

#include <functional>
#include <map>
#include <memory>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

// RemoteFusedGraphExecuteUtils provides APIs to register and get builder
// functions for IRemoteFusedGraphExecutor.
35 class RemoteFusedGraphExecuteUtils { 36 public: 37 // TODO(satok): Use "_output_data_types" to share a spec with other ops 38 static constexpr const char* const ATTR_OUTPUT_DATA_TYPES = 39 "_default_remote_graph_output_data_types"; 40 // TODO(satok): Use "_output_shapes" to share a spec with other ops 41 static constexpr const char* const ATTR_OUTPUT_SHAPES = 42 "_default_remote_output_shapes"; 43 static constexpr const char* const 44 ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO = 45 "serialized_remote_fused_graph_execute_info"; 46 static constexpr const char* const ATTR_NODE_TYPE = 47 "_remote_fused_graph_node_type"; 48 49 // Argument key strings to fuse a subgraph into RemoteFusedGraphExecuteOp. 50 static constexpr const char* const 51 TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME = 52 "remote_fused_graph_executor_name"; 53 static constexpr const char* const 54 TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME = 55 "remote_fused_graph_node_name"; 56 static constexpr const char* const TRANSFORM_ARG_FUSED_NODES = "fused_nodes"; 57 static constexpr const char* const TRANSFORM_ARG_BORDER_INPUTS = 58 "border_inputs"; 59 static constexpr const char* const TRANSFORM_ARG_BORDER_OUTPUTS = 60 "border_outputs"; 61 static constexpr const char* const TRANSFORM_ARG_FUSED_OP_TYPES = 62 "fused_op_types"; 63 static constexpr const char* const TRANSFORM_ARG_FUSE_BY_EXECUTOR = 64 "fuse_by_executor"; 65 static constexpr const char* const TRANSFORM_ARG_INPUT_TYPES = "input_types"; 66 static constexpr const char* const TRANSFORM_ARG_INPUT_SHAPES = 67 "input_shapes"; 68 69 using ExecutorBuildFunc = std::function<Status( 70 std::unique_ptr<IRemoteFusedGraphExecutor>* executor)>; 71 // Registrar class for IRemoteFusedGraphExecutor. 
72 class ExecutorBuildRegistrar { 73 public: 74 ExecutorBuildRegistrar(const string& name, ExecutorBuildFunc func); 75 76 private: 77 TF_DISALLOW_COPY_AND_ASSIGN(ExecutorBuildRegistrar); 78 }; 79 using ExecutorBuildRegistry = std::map<string, ExecutorBuildFunc>; 80 81 using TensorShapeType = std::pair<DataType, TensorShape>; 82 using TensorShapeMap = std::unordered_multimap<string, // node name 83 std::pair<int, // port 84 TensorShapeType>>; 85 using ClusterInfo = std::tuple<std::unordered_set<string>, // node names 86 std::vector<string>, // border inputs 87 std::vector<string>>; // border outputs 88 89 // Return registered ExecutorBuildFunc for given name. 90 static const ExecutorBuildFunc* GetExecutorBuildFunc(const string& name); 91 92 // To determine shapes of output tensors of all nodes, dryrun the graph. 93 // This function supplies memory allocation information when loading 94 // the graph. This function is used to verify shape inference and actual 95 // output shape. 96 static Status DryRunInference( 97 const GraphDef& graph_def, 98 const std::vector<std::pair<string, Tensor>>& input_node_info_list, 99 const std::vector<string>& output_node_names, 100 const bool initialize_by_zero, 101 std::vector<tensorflow::Tensor>* output_tensors); 102 103 // Dry run inference to obtain shapes for all nodes. 104 // CAVEAT: Do not add or modify output_tensors in output_tensor_info 105 // otherwise, address map may be broken by re-allocation inside 106 // std::vector. 
107 static Status DryRunInferenceForAllNode( 108 const GraphDef& graph_def, 109 const std::vector<std::pair<string, Tensor>>& input_node_info_list, 110 const bool initialize_by_zero, TensorShapeMap* tensor_shape_map); 111 112 static bool IsInputNode( 113 const std::vector<std::pair<string, Tensor>>& input_node_info_list, 114 const string& node_name); 115 116 static void ConvertToTensorShapeMap( 117 const std::vector<std::pair<string, Tensor>>& input_node_info_list, 118 const std::vector<string>& output_node_names, 119 const std::vector<tensorflow::Tensor>& output_tensors, 120 TensorShapeMap* tensor_shape_map); 121 122 static Status MakeTensorFromProto(const TensorProto& tensor_proto, 123 Tensor* tensor); 124 125 static bool AddOutputTensorShapeType(const std::vector<DataType>& data_types, 126 const std::vector<TensorShape>& shapes, 127 NodeDef* node_def); 128 129 static Status AddOutputTensorShapeTypeByTensorShapeMap( 130 const TensorShapeMap& tensor_shape_map, NodeDef* node_def); 131 132 static Status GetOutputTensorShapeType(AttrSlice attrs, 133 std::vector<DataType>* data_types, 134 std::vector<TensorShape>* shapes); 135 136 static bool GetOutputTensorShapeType(const GraphDef& graph_def, 137 const string& name_and_port, 138 DataType* data_type, TensorShape* shape); 139 140 static Status PropagateShapeInference( 141 const GraphDef& graph_def, 142 const std::vector<std::pair<string, Tensor>>& input_node_info_list, 143 Graph* graph, ShapeRefiner* shape_refiner); 144 145 static Status BuildTensorShapeMapFromGraph(const Graph& graph, 146 const ShapeRefiner& shape_refiner, 147 TensorShapeMap* tensor_shape_map); 148 149 static const TensorShapeType* GetTensorShapeType( 150 const TensorShapeMap& tensor_shape_map, const string& node_name); 151 152 static const TensorShapeType* GetTensorShapeType( 153 const TensorShapeMap& tensor_shape_map, const string& node_name, 154 const int port); 155 156 static void BuildRemoteGraphInputsAndOutputsFromProto( 157 const 
RemoteFusedGraphExecuteInfo& proto, 158 std::vector<std::pair<string, Tensor>>* inputs, 159 std::vector<string>* outputs); 160 161 static Status BuildAndAddTensorShapes( 162 const std::vector<std::pair<string, Tensor>>& input_tensors, 163 const bool dry_run_inference, GraphDef* graph_def); 164 165 // Build remote fused graph execute info. 166 static Status BuildRemoteFusedGraphExecuteInfo( 167 const string& executor_name, const GraphDef& subgraph_def, 168 const std::vector<string>& inputs, const std::vector<string>& outputs, 169 const bool require_shape_type, RemoteFusedGraphExecuteInfo* execute_info, 170 DataTypeVector* input_types, DataTypeVector* output_types); 171 172 // Build remote fused graph execute op node by fusing specified subgraph 173 // as remote fused graph execute info. 174 static Status BuildRemoteFusedGraphExecuteOpNode( 175 const string& node_name, const string& executor_name, 176 const GraphDef& subgraph_def, const std::vector<string>& inputs, 177 const std::vector<string>& outputs, const bool require_shape_type, 178 Graph* graph, Node** created_node); 179 180 // Build Identity node to forward remote graph node output. 181 static Status BuildIdentityOpNode(const string& node_name, 182 const string& input_node_name, 183 const int input_node_port, 184 const DataType dt, Graph* graph, 185 Node** created_node); 186 187 // Create clusters of given nodes. 188 static Status ClusterizeNodes(const std::unordered_set<string>& node_names, 189 const GraphDef& graph_def, 190 std::vector<ClusterInfo>* cluster_infos); 191 192 // Build GraphDef of a given cluster. 193 static Status BuildClusterSubgraphDef(const ClusterInfo& cluster, 194 const GraphDef& graph_def, 195 GraphDef* subgraph_def); 196 197 // Build a cluster by given border. 198 // CAVEAT: The border must be consistent for one cluster. 
199 static Status BuildClusterByBorder(const std::vector<string>& border_inputs, 200 const std::vector<string>& border_outputs, 201 const GraphDef& graph_def, 202 ClusterInfo* cluster); 203 204 // Fuse one cluster into a newly created RemoteFusedGraphExecuteOp node. 205 // The subgraph is stored as a graph in RemoteFusedGraphExecuteInfo. 206 // CAVEAT1: This transform strips unvisited nodes with given outputs. 207 // CAVEAT2: If you want to use a graph output as a border output, 208 // that graph output node is replaced by an identity node. Therefore, 209 // the number of output of the node must be 1. 210 static Status FuseCluster(const GraphDef& input_graph_def, 211 const std::vector<string>& inputs, 212 const std::vector<string>& outputs, 213 const string& remote_fused_graph_node_name, 214 const ClusterInfo& cluster, 215 const string& remote_graph_executor_name, 216 const bool require_shape_type, 217 GraphDef* output_graph_def); 218 219 // Fuse subgraph of specified nodes. 220 static Status FuseRemoteGraphByNodeNames( 221 const GraphDef& input_graph_def, const std::vector<string>& inputs, 222 const std::vector<string>& outputs, 223 const string& remote_fused_graph_node_name_prefix, 224 const std::unordered_set<string>& subgraph_nodes, 225 const string& remote_fused_graph_executor_name, 226 const bool require_shape_type, GraphDef* output_graph_def); 227 228 // Fuse subgraph of specified border. 229 static Status FuseRemoteGraphByBorder( 230 const GraphDef& input_graph_def, const std::vector<string>& inputs, 231 const std::vector<string>& outputs, 232 const string& remote_fused_graph_node_name, 233 const std::vector<string>& border_inputs, 234 const std::vector<string>& border_outputs, 235 const string& remote_graph_executor_name, const bool require_shape_type, 236 GraphDef* output_graph_def); 237 238 // Fuse subgraph of specified op types. 
239 static Status FuseRemoteGraphByOpTypes( 240 const GraphDef& input_graph_def, const std::vector<string>& inputs, 241 const std::vector<string>& outputs, 242 const string& remote_fused_graph_node_name_prefix, 243 const std::unordered_set<string>& fused_op_types, 244 const string& remote_fused_graph_executor_name, 245 const bool require_shape_type, GraphDef* output_graph_def); 246 247 // Place arguments to fuse remote graph. 248 static Status PlaceRemoteGraphArguments( 249 const std::vector<string>& inputs, const std::vector<string>& outputs, 250 const std::unordered_set<string>& fused_node_names, 251 const std::vector<string>& border_inputs, 252 const std::vector<string>& border_outputs, 253 const std::unordered_set<string>& fused_op_types, 254 const string& remote_fused_graph_node_name, 255 const string& remote_graph_executor_name, GraphDef* graph_def); 256 257 // Fuse remote graph by placed arguments. 258 static Status FuseRemoteGraphByPlacedArguments( 259 const GraphDef& input_graph_def, 260 const std::vector<std::pair<string, Tensor>>& input_tensors, 261 GraphDef* output_graph_def); 262 263 static Status FuseRemoteGraphByExecutor(const GraphDef& input_graph_def, 264 const std::vector<string>& inputs, 265 const std::vector<string>& outputs, 266 const string& executor_name, 267 GraphDef* output_graph_def); 268 269 static bool IsFuseReady( 270 const GraphDef& input_graph_def, 271 const std::vector<std::pair<string, Tensor>>& input_tensors); 272 273 // Copy a byte array to a tensor data. Though tensor data must be 274 // updated with typed information in general, we can't guarantee that 275 // returned values from a remote processor has typed information because 276 // a logic running in the remote processor possibly be in a separate binary 277 // which may not link tensorflow libraries. To deal with this situation, 278 // remote fused graph needs to overwrite the tensor data by a byte array. 
279 static Status CopyByteArrayToTensor(const void* src_ptr, const int src_size, 280 Tensor* tensor); 281 282 static std::unordered_set<string> BuildNodeMapFromOpTypes( 283 const GraphDef& graph_def, const std::unordered_set<string>& op_types); 284 285 static std::unordered_set<string> BuildNodeMapFromOpsDefinitions( 286 const GraphDef& graph_def, 287 const IRemoteFusedGraphOpsDefinitions& ops_definitions); 288 289 private: 290 static void EmplaceTensorShapeType(const string& name, const Tensor& tensor, 291 TensorShapeMap* tensor_shape_map); 292 293 static Status ReplaceInputNodeByPlaceHolder(const string& input, 294 const DataType type, 295 const TensorShape& shape, 296 GraphDef* graph_def); 297 298 static ExecutorBuildRegistry* GetExecutorBuildRegistry(); 299 300 static string BuildNodeTypeAttr( 301 const RemoteFusedGraphExecuteInfo::NodeType node_type, const int port, 302 const int index, const string& executor_name, const string& node_name); 303 304 static string BuildNodeTypeAttr( 305 const RemoteFusedGraphExecuteInfo::NodeType node_type, const int port, 306 const int index); 307 308 static string BuildNodeTypeAttr( 309 const RemoteFusedGraphExecuteInfo::NodeType node_type); 310 311 TF_DISALLOW_COPY_AND_ASSIGN(RemoteFusedGraphExecuteUtils); 312 }; 313 } // namespace tensorflow 314 315 #endif // TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_UTILS_H_ 316