Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // Wraps the remote fused graph rewriter (used e.g. for Hexagon) in a
     17 // transform so it can be used as part of the graph transform tool.
     18 // A usage example, based on inception v3 model:
     19 /*
     20 bazel build tensorflow/tools/graph_transforms:transform_graph
     21 
     22 
     23 // Specify remote graph by node names
     24 bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
     25 --in_graph=/tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb \
     26 --out_graph=\
     27 /tmp/tensorflow_inception_v3_stripped_optimized_quantized_fused_hexagon.pb \
     28 --inputs='Mul' \
     29 --outputs='softmax' \
     30 --transforms='\
     31 fuse_remote_graph(
     32 input_types="float" \
     33 input_shapes="1,299,299,3" \
     34 fused_nodes="NodeA,NodeB,NodeC" \
     35 remote_fused_graph_executor_name="executor" \
     36 remote_fused_graph_node_name="node_name" \
     37 )'
     38 
     39 // Specify remote graph by border inputs and outputs
     40 bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
     41 --in_graph=/tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb \
     42 --out_graph=\
     43 /tmp/tensorflow_inception_v3_stripped_optimized_quantized_fused_hexagon.pb \
     44 --inputs='Mul' \
     45 --outputs='softmax' \
     46 --transforms='\
     47 fuse_remote_graph(
     48 input_types="float" \
     49 input_shapes="1,299,299,3" \
     50 border_inputs="NodeA:0,NodeB:0" \
     51 border_outputs="NodeC" \
     52 remote_fused_graph_executor_name="executor" \
     53 remote_fused_graph_node_name="node_name" \
     54 )'
     55 */
     56 
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
     61 
     62 namespace tensorflow {
     63 namespace graph_transforms {
     64 
     65 static Status ParseArguments(const TransformFuncContext& context,
     66                              string* input_types_str, string* input_shapes_str,
     67                              string* fused_nodes_str, string* border_inputs_str,
     68                              string* border_outputs_str,
     69                              string* fused_op_types_str, bool* fuse_by_executor,
     70                              string* remote_fused_graph_node_name,
     71                              string* remote_graph_executor_name) {
     72   TF_RETURN_IF_ERROR(context.GetOneStringParameter(
     73       RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES, "",
     74       input_types_str));
     75   TF_RETURN_IF_ERROR(context.GetOneStringParameter(
     76       RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_SHAPES, "",
     77       input_shapes_str));
     78   TF_RETURN_IF_ERROR(context.GetOneStringParameter(
     79       RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_NODES, "",
     80       fused_nodes_str));
     81   TF_RETURN_IF_ERROR(context.GetOneStringParameter(
     82       RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_INPUTS, "",
     83       border_inputs_str));
     84   TF_RETURN_IF_ERROR(context.GetOneStringParameter(
     85       RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS, "",
     86       border_outputs_str));
     87   TF_RETURN_IF_ERROR(context.GetOneStringParameter(
     88       RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES, "",
     89       fused_op_types_str));
     90   TF_RETURN_IF_ERROR(context.GetOneBoolParameter(
     91       RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR, false,
     92       fuse_by_executor));
     93   TF_RETURN_IF_ERROR(context.GetOneStringParameter(
     94       RemoteFusedGraphExecuteUtils::
     95           TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
     96       "", remote_graph_executor_name));
     97   TF_RETURN_IF_ERROR(context.GetOneStringParameter(
     98       RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME,
     99       "", remote_fused_graph_node_name));
    100 
    101   CHECK(!remote_graph_executor_name->empty());
    102   return Status::OK();
    103 }
    104 
    105 static Status PlaceShapeType(const std::vector<string>& inputs,
    106                              const std::vector<string>& outputs,
    107                              const string& input_types_str,
    108                              const string& input_shapes_str,
    109                              GraphDef* mutable_input_graph_def) {
    110   const std::vector<string> input_types_strs =
    111       str_util::Split(input_types_str, ",");
    112   const std::vector<string> input_shapes_strs =
    113       str_util::Split(input_shapes_str, ":");
    114   CHECK_EQ(inputs.size(), input_types_strs.size());
    115   CHECK_EQ(inputs.size(), input_shapes_strs.size());
    116   std::vector<std::pair<string, Tensor>> input_tensors;
    117   for (size_t i = 0; i < inputs.size(); ++i) {
    118     const string& name = inputs.at(i);
    119     std::vector<int64> dims;
    120     CHECK(str_util::SplitAndParseAsInts(input_shapes_strs.at(i), ',', &dims));
    121     DataType data_type;
    122     CHECK(DataTypeFromString(input_types_strs.at(i), &data_type))
    123         << "\"" << input_types_strs.at(i) << "\" was an invalid type";
    124     input_tensors.emplace_back(
    125         std::make_pair(name, Tensor(data_type, TensorShape(dims))));
    126   }
    127   TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes(
    128       input_tensors, /*dry_run_inference=*/true, mutable_input_graph_def));
    129   return Status::OK();
    130 }
    131 
    132 Status FuseRemoteGraph(const GraphDef& input_graph_def,
    133                        const TransformFuncContext& context,
    134                        GraphDef* output_graph_def) {
    135   GraphDef mutable_input_graph_def = input_graph_def;
    136 
    137   const std::vector<string>& inputs = context.input_names;
    138   const std::vector<string>& outputs = context.output_names;
    139 
    140   string input_types_str;
    141   string input_shapes_str;
    142   string fused_nodes_str;
    143   string border_inputs_str;
    144   string border_outputs_str;
    145   string fused_op_types_str;
    146   bool fuse_by_executor = false;
    147   string remote_fused_graph_node_name;
    148   string remote_graph_executor_name;
    149   TF_RETURN_IF_ERROR(ParseArguments(
    150       context, &input_types_str, &input_shapes_str, &fused_nodes_str,
    151       &border_inputs_str, &border_outputs_str, &fused_op_types_str,
    152       &fuse_by_executor, &remote_fused_graph_node_name,
    153       &remote_graph_executor_name));
    154 
    155   if (!input_types_str.empty()) {
    156     TF_RETURN_IF_ERROR(PlaceShapeType(inputs, outputs, input_types_str,
    157                                       input_shapes_str,
    158                                       &mutable_input_graph_def));
    159   }
    160 
    161   const bool require_shape_type = !input_types_str.empty();
    162   if (!fused_nodes_str.empty()) {
    163     const std::vector<string> fused_node_name_vector =
    164         str_util::Split(fused_nodes_str, ",");
    165     const std::unordered_set<string> fused_node_names(
    166         fused_node_name_vector.begin(), fused_node_name_vector.end());
    167     TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::FuseRemoteGraphByNodeNames(
    168         mutable_input_graph_def, inputs, outputs, remote_fused_graph_node_name,
    169         fused_node_names, remote_graph_executor_name, require_shape_type,
    170         output_graph_def));
    171   } else if (!border_inputs_str.empty() && !border_outputs_str.empty()) {
    172     const std::vector<string> border_inputs =
    173         str_util::Split(border_inputs_str, ",");
    174     const std::vector<string> border_outputs =
    175         str_util::Split(border_outputs_str, ",");
    176     for (size_t i = 0; i < border_inputs.size(); ++i) {
    177       VLOG(2) << "Border Input(" << i << "): " << border_inputs.at(i);
    178     }
    179     for (size_t i = 0; i < border_outputs.size(); ++i) {
    180       VLOG(2) << "Border Output(" << i << "): " << border_outputs.at(i);
    181     }
    182     TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::FuseRemoteGraphByBorder(
    183         mutable_input_graph_def, inputs, outputs, remote_fused_graph_node_name,
    184         border_inputs, border_outputs, remote_graph_executor_name,
    185         require_shape_type, output_graph_def));
    186   } else if (!fused_op_types_str.empty()) {
    187     const std::vector<string> fused_op_type_vector =
    188         str_util::Split(fused_op_types_str, ",");
    189     const std::unordered_set<string> fused_op_types(
    190         fused_op_type_vector.begin(), fused_op_type_vector.end());
    191     TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::FuseRemoteGraphByOpTypes(
    192         mutable_input_graph_def, inputs, outputs, remote_fused_graph_node_name,
    193         fused_op_types, remote_graph_executor_name, require_shape_type,
    194         output_graph_def));
    195   } else if (fuse_by_executor) {
    196     TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor(
    197         mutable_input_graph_def, inputs, outputs, remote_graph_executor_name,
    198         output_graph_def));
    199   } else {
    200     LOG(FATAL) << "Fuse targets are not specified.";
    201   }
    202 
    203   return Status::OK();
    204 }
    205 
    206 Status PlaceRemoteGraphArguments(const GraphDef& input_graph_def,
    207                                  const TransformFuncContext& context,
    208                                  GraphDef* output_graph_def) {
    209   *output_graph_def = input_graph_def;
    210 
    211   const std::vector<string>& inputs = context.input_names;
    212   const std::vector<string>& outputs = context.output_names;
    213 
    214   string input_types_str;
    215   string input_shapes_str;
    216   string fused_nodes_str;
    217   string border_inputs_str;
    218   string border_outputs_str;
    219   string fused_op_types_str;
    220   bool fuse_by_executor = false;
    221   string remote_fused_graph_node_name;
    222   string remote_graph_executor_name;
    223   TF_RETURN_IF_ERROR(ParseArguments(
    224       context, &input_types_str, &input_shapes_str, &fused_nodes_str,
    225       &border_inputs_str, &border_outputs_str, &fused_op_types_str,
    226       &fuse_by_executor, &remote_fused_graph_node_name,
    227       &remote_graph_executor_name));
    228 
    229   if (!input_types_str.empty()) {
    230     TF_RETURN_IF_ERROR(PlaceShapeType(inputs, outputs, input_types_str,
    231                                       input_shapes_str, output_graph_def));
    232   }
    233 
    234   const std::vector<string> fused_node_name_vector =
    235       str_util::Split(fused_nodes_str, ",");
    236   const std::unordered_set<string> fused_node_names(
    237       fused_node_name_vector.begin(), fused_node_name_vector.end());
    238   const std::vector<string> border_inputs =
    239       str_util::Split(border_inputs_str, ",");
    240   const std::vector<string> border_outputs =
    241       str_util::Split(border_outputs_str, ",");
    242   const std::vector<string> fused_op_type_vector =
    243       str_util::Split(fused_op_types_str, ",");
    244   const std::unordered_set<string> fused_op_types(fused_op_type_vector.begin(),
    245                                                   fused_op_type_vector.end());
    246 
    247   TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::PlaceRemoteGraphArguments(
    248       inputs, outputs, fused_node_names, border_inputs, border_outputs,
    249       fused_op_types, remote_fused_graph_node_name, remote_graph_executor_name,
    250       output_graph_def));
    251 
    252   return Status::OK();
    253 }
    254 
    255 REGISTER_GRAPH_TRANSFORM("fuse_remote_graph", FuseRemoteGraph);
    256 
    257 REGISTER_GRAPH_TRANSFORM("place_remote_graph_arguments",
    258                          PlaceRemoteGraphArguments);
    259 
    260 }  // namespace graph_transforms
    261 }  // namespace tensorflow
    262