Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_UTILS_H_
     17 #define TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_UTILS_H_
     18 
     19 #include <unordered_map>
     20 #include <unordered_set>
     21 
     22 #include "tensorflow/core/framework/graph.pb.h"
     23 #include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
     24 #include "tensorflow/core/graph/graph.h"
     25 #include "tensorflow/core/graph/graph_constructor.h"
     26 #include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
     27 #include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
     28 #include "tensorflow/core/lib/core/status.h"
     29 #include "tensorflow/core/platform/macros.h"
     30 
     31 namespace tensorflow {
     32 
     33 // RemoteFusedGraphExecuteUtils provides APIs to register and get builder
     34 // functions for IRemoteFusedGraphExecutor.
        //
        // Besides the builder registry, this class declares static helpers for
        // dry-run / shape inference, for attaching output shape and type
        // information to nodes, and for fusing parts of a GraphDef into a
        // RemoteFusedGraphExecuteOp node.  Every data member and member function
        // declared here is static, and only declarations are visible in this
        // header.
        //
        // NOTE(review): std::function, std::map, std::unique_ptr, std::pair,
        // std::tuple and std::vector are used below, but <functional>, <map>,
        // <memory>, <utility>, <tuple> and <vector> are not among this header's
        // direct includes -- presumably they arrive transitively; confirm.
     35 class RemoteFusedGraphExecuteUtils {
     36  public:
          // ATTR_* string constants.
          // NOTE(review): presumably these are node attribute names read and
          // written by the *OutputTensorShapeType helpers below -- confirm against
          // the implementation file.
     37   // TODO(satok): Use "_output_data_types" to share a spec with other ops
     38   static constexpr const char* const ATTR_OUTPUT_DATA_TYPES =
     39       "_default_remote_graph_output_data_types";
     40   // TODO(satok): Use "_output_shapes" to share a spec with other ops
     41   static constexpr const char* const ATTR_OUTPUT_SHAPES =
     42       "_default_remote_output_shapes";
     43   static constexpr const char* const
     44       ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO =
     45           "serialized_remote_fused_graph_execute_info";
     46   static constexpr const char* const ATTR_NODE_TYPE =
     47       "_remote_fused_graph_node_type";
     48 
     49   // Argument key strings to fuse a subgraph into RemoteFusedGraphExecuteOp.
     50   static constexpr const char* const
     51       TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME =
     52           "remote_fused_graph_executor_name";
     53   static constexpr const char* const
     54       TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME =
     55           "remote_fused_graph_node_name";
     56   static constexpr const char* const TRANSFORM_ARG_FUSED_NODES = "fused_nodes";
     57   static constexpr const char* const TRANSFORM_ARG_BORDER_INPUTS =
     58       "border_inputs";
     59   static constexpr const char* const TRANSFORM_ARG_BORDER_OUTPUTS =
     60       "border_outputs";
     61   static constexpr const char* const TRANSFORM_ARG_FUSED_OP_TYPES =
     62       "fused_op_types";
     63   static constexpr const char* const TRANSFORM_ARG_FUSE_BY_EXECUTOR =
     64       "fuse_by_executor";
     65   static constexpr const char* const TRANSFORM_ARG_INPUT_TYPES = "input_types";
     66   static constexpr const char* const TRANSFORM_ARG_INPUT_SHAPES =
     67       "input_shapes";
     68 
          // Callable that takes a pointer to a
          // std::unique_ptr<IRemoteFusedGraphExecutor> and returns a Status.
          // NOTE(review): presumably it stores a newly built executor in
          // `*executor` -- confirm.
     69   using ExecutorBuildFunc = std::function<Status(
     70       std::unique_ptr<IRemoteFusedGraphExecutor>* executor)>;
     71   // Registrar class for IRemoteFusedGraphExecutor.
          // The constructor is only declared here.
          // NOTE(review): presumably it records `func` under `name` in the
          // ExecutorBuildRegistry -- confirm.
     72   class ExecutorBuildRegistrar {
     73    public:
     74     ExecutorBuildRegistrar(const string& name, ExecutorBuildFunc func);
     75 
     76    private:
     77     TF_DISALLOW_COPY_AND_ASSIGN(ExecutorBuildRegistrar);
     78   };
          // Ordered map from a string key (presumably the executor name) to an
          // ExecutorBuildFunc.
     79   using ExecutorBuildRegistry = std::map<string, ExecutorBuildFunc>;
     80 
          // Pair of a DataType and a TensorShape.
     81   using TensorShapeType = std::pair<DataType, TensorShape>;
          // Multimap keyed by node name; each mapped value pairs a port with a
          // TensorShapeType.
     82   using TensorShapeMap = std::unordered_multimap<string,         // node name
     83                                                  std::pair<int,  // port
     84                                                            TensorShapeType>>;
          // Tuple describing one cluster: node names, border inputs, border
          // outputs (in that order).
     85   using ClusterInfo = std::tuple<std::unordered_set<string>,  // node names
     86                                  std::vector<string>,         // border inputs
     87                                  std::vector<string>>;        // border outputs
     88 
     89   // Return registered ExecutorBuildFunc for given name.
          // NOTE(review): the result is a pointer; presumably nullptr when nothing
          // is registered under `name` -- confirm.
     90   static const ExecutorBuildFunc* GetExecutorBuildFunc(const string& name);
     91 
     92   // To determine shapes of output tensors of all nodes, dryrun the graph.
     93   // This function supplies memory allocation information when loading
     94   // the graph. This function is used to verify shape inference and actual
     95   // output shape.
          // NOTE(review): `output_tensors` presumably receives one tensor per entry
          // of `output_node_names`, and `initialize_by_zero` presumably controls
          // zero-filling of tensors used for the run -- confirm both.
     96   static Status DryRunInference(
     97       const GraphDef& graph_def,
     98       const std::vector<std::pair<string, Tensor>>& input_node_info_list,
     99       const std::vector<string>& output_node_names,
    100       const bool initialize_by_zero,
    101       std::vector<tensorflow::Tensor>* output_tensors);
    102 
    103   // Dry run inference to obtain shapes for all nodes.
    104   // CAVEAT: Do not add or modify output_tensors in output_tensor_info
    105   // otherwise, address map may be broken by re-allocation inside
    106   // std::vector.
          // NOTE(review): `output_tensors` and `output_tensor_info` are not
          // parameters of this declaration; the caveat above may be stale or may
          // refer to implementation details -- confirm.
    107   static Status DryRunInferenceForAllNode(
    108       const GraphDef& graph_def,
    109       const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    110       const bool initialize_by_zero, TensorShapeMap* tensor_shape_map);
    111 
          // NOTE(review): presumably returns true when `input_node_info_list`
          // contains an entry named `node_name` -- confirm.
    112   static bool IsInputNode(
    113       const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    114       const string& node_name);
    115 
          // Returns nothing; `tensor_shape_map` is the only non-const parameter.
          // NOTE(review): presumably fills the map from `output_node_names` and
          // the corresponding `output_tensors` -- confirm.
    116   static void ConvertToTensorShapeMap(
    117       const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    118       const std::vector<string>& output_node_names,
    119       const std::vector<tensorflow::Tensor>& output_tensors,
    120       TensorShapeMap* tensor_shape_map);
    121 
          // NOTE(review): presumably converts `tensor_proto` into `*tensor` and
          // reports failure through the returned Status -- confirm.
    122   static Status MakeTensorFromProto(const TensorProto& tensor_proto,
    123                                     Tensor* tensor);
    124 
          // NOTE(review): presumably records `data_types` and `shapes` on
          // `node_def` (see ATTR_OUTPUT_DATA_TYPES / ATTR_OUTPUT_SHAPES).  The
          // meaning of the bool result is not visible in this header -- confirm.
    125   static bool AddOutputTensorShapeType(const std::vector<DataType>& data_types,
    126                                        const std::vector<TensorShape>& shapes,
    127                                        NodeDef* node_def);
    128 
          // Variant of the above that takes a TensorShapeMap and returns a Status.
          // NOTE(review): how entries of `tensor_shape_map` are matched to
          // `node_def` is not visible here -- confirm.
    129   static Status AddOutputTensorShapeTypeByTensorShapeMap(
    130       const TensorShapeMap& tensor_shape_map, NodeDef* node_def);
    131 
          // Overload taking an AttrSlice; `data_types` and `shapes` are non-const
          // pointers, presumably outputs.  Returns a Status.
    132   static Status GetOutputTensorShapeType(AttrSlice attrs,
    133                                          std::vector<DataType>* data_types,
    134                                          std::vector<TensorShape>* shapes);
    135 
          // Overload taking a GraphDef and a `name_and_port` string; `data_type`
          // and `shape` are non-const pointers, presumably outputs.  Returns bool.
          // NOTE(review): the expected format of `name_and_port` is not visible in
          // this header -- confirm.
    136   static bool GetOutputTensorShapeType(const GraphDef& graph_def,
    137                                        const string& name_and_port,
    138                                        DataType* data_type, TensorShape* shape);
    139 
          // Takes mutable `graph` and `shape_refiner`.
          // NOTE(review): presumably builds `graph` from `graph_def` and runs shape
          // inference using the given input tensors -- confirm.
    140   static Status PropagateShapeInference(
    141       const GraphDef& graph_def,
    142       const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    143       Graph* graph, ShapeRefiner* shape_refiner);
    144 
          // NOTE(review): presumably fills `tensor_shape_map` with the shapes that
          // `shape_refiner` holds for the nodes of `graph` -- confirm.
    145   static Status BuildTensorShapeMapFromGraph(const Graph& graph,
    146                                              const ShapeRefiner& shape_refiner,
    147                                              TensorShapeMap* tensor_shape_map);
    148 
          // Looks up `node_name` in `tensor_shape_map`.
          // NOTE(review): which port this overload selects, and whether nullptr
          // signals a miss, are not visible in this header -- confirm.
    149   static const TensorShapeType* GetTensorShapeType(
    150       const TensorShapeMap& tensor_shape_map, const string& node_name);
    151 
          // Overload of the lookup above that also takes an explicit `port`.
    152   static const TensorShapeType* GetTensorShapeType(
    153       const TensorShapeMap& tensor_shape_map, const string& node_name,
    154       const int port);
    155 
          // NOTE(review): presumably extracts the input (name, tensor) pairs and
          // the output names described by `proto` into `inputs` and `outputs` --
          // confirm.
    156   static void BuildRemoteGraphInputsAndOutputsFromProto(
    157       const RemoteFusedGraphExecuteInfo& proto,
    158       std::vector<std::pair<string, Tensor>>* inputs,
    159       std::vector<string>* outputs);
    160 
          // Takes a mutable `graph_def`.
          // NOTE(review): `dry_run_inference` presumably selects a dry run over
          // static shape inference as the source of the shapes -- confirm.
    161   static Status BuildAndAddTensorShapes(
    162       const std::vector<std::pair<string, Tensor>>& input_tensors,
    163       const bool dry_run_inference, GraphDef* graph_def);
    164 
    165   // Build remote fused graph execute info.
          // `execute_info`, `input_types` and `output_types` are non-const
          // pointers, presumably outputs.
    166   static Status BuildRemoteFusedGraphExecuteInfo(
    167       const string& executor_name, const GraphDef& subgraph_def,
    168       const std::vector<string>& inputs, const std::vector<string>& outputs,
    169       const bool require_shape_type, RemoteFusedGraphExecuteInfo* execute_info,
    170       DataTypeVector* input_types, DataTypeVector* output_types);
    171 
    172   // Build remote fused graph execute op node by fusing specified subgraph
    173   // as remote fused graph execute info.
          // NOTE(review): `created_node` presumably receives the node added to
          // `graph` -- confirm.
    174   static Status BuildRemoteFusedGraphExecuteOpNode(
    175       const string& node_name, const string& executor_name,
    176       const GraphDef& subgraph_def, const std::vector<string>& inputs,
    177       const std::vector<string>& outputs, const bool require_shape_type,
    178       Graph* graph, Node** created_node);
    179 
    180   // Build Identity node to forward remote graph node output.
    181   static Status BuildIdentityOpNode(const string& node_name,
    182                                     const string& input_node_name,
    183                                     const int input_node_port,
    184                                     const DataType dt, Graph* graph,
    185                                     Node** created_node);
    186 
    187   // Create clusters of given nodes.
          // Results are presumably appended to or stored in `cluster_infos`.
    188   static Status ClusterizeNodes(const std::unordered_set<string>& node_names,
    189                                 const GraphDef& graph_def,
    190                                 std::vector<ClusterInfo>* cluster_infos);
    191 
    192   // Build GraphDef of a given cluster.
    193   static Status BuildClusterSubgraphDef(const ClusterInfo& cluster,
    194                                         const GraphDef& graph_def,
    195                                         GraphDef* subgraph_def);
    196 
    197   // Build a cluster by given border.
    198   // CAVEAT: The border must be consistent for one cluster.
    199   static Status BuildClusterByBorder(const std::vector<string>& border_inputs,
    200                                      const std::vector<string>& border_outputs,
    201                                      const GraphDef& graph_def,
    202                                      ClusterInfo* cluster);
    203 
    204   // Fuse one cluster into a newly created RemoteFusedGraphExecuteOp node.
    205   // The subgraph is stored as a graph in RemoteFusedGraphExecuteInfo.
    206   // CAVEAT1: This transform strips unvisited nodes with given outputs.
    207   // CAVEAT2: If you want to use a graph output as a border output,
    208   // that graph output node is replaced by an identity node.  Therefore,
    209   // the number of outputs of that node must be 1.
    210   static Status FuseCluster(const GraphDef& input_graph_def,
    211                             const std::vector<string>& inputs,
    212                             const std::vector<string>& outputs,
    213                             const string& remote_fused_graph_node_name,
    214                             const ClusterInfo& cluster,
    215                             const string& remote_graph_executor_name,
    216                             const bool require_shape_type,
    217                             GraphDef* output_graph_def);
    218 
    219   // Fuse subgraph of specified nodes.
    220   static Status FuseRemoteGraphByNodeNames(
    221       const GraphDef& input_graph_def, const std::vector<string>& inputs,
    222       const std::vector<string>& outputs,
    223       const string& remote_fused_graph_node_name_prefix,
    224       const std::unordered_set<string>& subgraph_nodes,
    225       const string& remote_fused_graph_executor_name,
    226       const bool require_shape_type, GraphDef* output_graph_def);
    227 
    228   // Fuse subgraph of specified border.
    229   static Status FuseRemoteGraphByBorder(
    230       const GraphDef& input_graph_def, const std::vector<string>& inputs,
    231       const std::vector<string>& outputs,
    232       const string& remote_fused_graph_node_name,
    233       const std::vector<string>& border_inputs,
    234       const std::vector<string>& border_outputs,
    235       const string& remote_graph_executor_name, const bool require_shape_type,
    236       GraphDef* output_graph_def);
    237 
    238   // Fuse subgraph of specified op types.
    239   static Status FuseRemoteGraphByOpTypes(
    240       const GraphDef& input_graph_def, const std::vector<string>& inputs,
    241       const std::vector<string>& outputs,
    242       const string& remote_fused_graph_node_name_prefix,
    243       const std::unordered_set<string>& fused_op_types,
    244       const string& remote_fused_graph_executor_name,
    245       const bool require_shape_type, GraphDef* output_graph_def);
    246 
    247   // Place arguments to fuse remote graph.
          // NOTE(review): presumably the arguments are recorded inside `graph_def`
          // for later use by FuseRemoteGraphByPlacedArguments -- confirm.
    248   static Status PlaceRemoteGraphArguments(
    249       const std::vector<string>& inputs, const std::vector<string>& outputs,
    250       const std::unordered_set<string>& fused_node_names,
    251       const std::vector<string>& border_inputs,
    252       const std::vector<string>& border_outputs,
    253       const std::unordered_set<string>& fused_op_types,
    254       const string& remote_fused_graph_node_name,
    255       const string& remote_graph_executor_name, GraphDef* graph_def);
    256 
    257   // Fuse remote graph by placed arguments.
    258   static Status FuseRemoteGraphByPlacedArguments(
    259       const GraphDef& input_graph_def,
    260       const std::vector<std::pair<string, Tensor>>& input_tensors,
    261       GraphDef* output_graph_def);
    262 
          // Fuse variant that takes only graph inputs/outputs and an executor name.
          // NOTE(review): how the nodes to fuse are chosen is not visible in this
          // header; presumably the named executor decides -- confirm.
    263   static Status FuseRemoteGraphByExecutor(const GraphDef& input_graph_def,
    264                                           const std::vector<string>& inputs,
    265                                           const std::vector<string>& outputs,
    266                                           const string& executor_name,
    267                                           GraphDef* output_graph_def);
    268 
          // NOTE(review): presumably reports whether `input_graph_def` together
          // with `input_tensors` carries everything
          // FuseRemoteGraphByPlacedArguments needs -- confirm.
    269   static bool IsFuseReady(
    270       const GraphDef& input_graph_def,
    271       const std::vector<std::pair<string, Tensor>>& input_tensors);
    272 
    273   // Copy a byte array to a tensor data.  Though tensor data must be
    274   // updated with typed information in general, we can't guarantee that
    275   // values returned from a remote processor have typed information because
    276   // the logic running on the remote processor may be in a separate binary
    277   // which may not link tensorflow libraries.  To deal with this situation,
    278   // remote fused graph needs to overwrite the tensor data by a byte array.
          // NOTE(review): `src_size` is presumably the number of bytes at
          // `src_ptr` -- confirm.
    279   static Status CopyByteArrayToTensor(const void* src_ptr, const int src_size,
    280                                       Tensor* tensor);
    281 
          // Despite the "NodeMap" name, the return type is a set of strings.
          // NOTE(review): presumably the names of nodes in `graph_def` whose op
          // type is in `op_types` -- confirm.
    282   static std::unordered_set<string> BuildNodeMapFromOpTypes(
    283       const GraphDef& graph_def, const std::unordered_set<string>& op_types);
    284 
          // Same return type as above, driven by an IRemoteFusedGraphOpsDefinitions
          // instead of an explicit op-type set.
          // NOTE(review): selection criteria are not visible here -- confirm.
    285   static std::unordered_set<string> BuildNodeMapFromOpsDefinitions(
    286       const GraphDef& graph_def,
    287       const IRemoteFusedGraphOpsDefinitions& ops_definitions);
    288 
    289  private:
          // NOTE(review): presumably inserts the type and shape of `tensor` into
          // `tensor_shape_map` under `name` -- confirm.
    290   static void EmplaceTensorShapeType(const string& name, const Tensor& tensor,
    291                                      TensorShapeMap* tensor_shape_map);
    292 
          // NOTE(review): presumably replaces the node named `input` in
          // `graph_def` with a placeholder of the given `type` and `shape` --
          // confirm.
    293   static Status ReplaceInputNodeByPlaceHolder(const string& input,
    294                                               const DataType type,
    295                                               const TensorShape& shape,
    296                                               GraphDef* graph_def);
    297 
          // Accessor returning a pointer to an ExecutorBuildRegistry.
          // NOTE(review): ownership and lifetime of the registry are not visible
          // in this header -- confirm.
    298   static ExecutorBuildRegistry* GetExecutorBuildRegistry();
    299 
          // Three overloads that return a string built from a node type and,
          // depending on the overload, a port, an index, an executor name and a
          // node name.
          // NOTE(review): presumably the value stored under ATTR_NODE_TYPE; the
          // string format is not visible here -- confirm.
    300   static string BuildNodeTypeAttr(
    301       const RemoteFusedGraphExecuteInfo::NodeType node_type, const int port,
    302       const int index, const string& executor_name, const string& node_name);
    303 
    304   static string BuildNodeTypeAttr(
    305       const RemoteFusedGraphExecuteInfo::NodeType node_type, const int port,
    306       const int index);
    307 
    308   static string BuildNodeTypeAttr(
    309       const RemoteFusedGraphExecuteInfo::NodeType node_type);
    310 
          // By its name, this macro disables copy construction and assignment
          // (presumably defined in tensorflow/core/platform/macros.h).
    311   TF_DISALLOW_COPY_AND_ASSIGN(RemoteFusedGraphExecuteUtils);
    312 };
    313 }  // namespace tensorflow
    314 
    315 #endif  // TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_UTILS_H_
    316