1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Wraps the hexagon rewriter in a transform so it can be used as part of the 17 // graph transform tool. 18 // A usage example, based on inception v3 model: 19 /* 20 bazel build tensorflow/tools/graph_transforms:transform_graph 21 22 23 // Specify remote graph by node names 24 bazel-bin/tensorflow/tools/graph_transforms/transform_graph \ 25 --in_graph=/tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb \ 26 --out_graph=\ 27 /tmp/tensorflow_inception_v3_stripped_optimized_quantized_fused_hexagon.pb \ 28 --inputs='Mul' \ 29 --outputs='softmax' \ 30 --transforms='\ 31 fuse_remote_graph( 32 input_types="float" \ 33 input_shapes="1,299,299,3" \ 34 fused_nodes="NodeA,NodeB,NodeC", 35 remote_fused_graph_executor_name="executor" \ 36 remote_fused_graph_node_name="node_name" \ 37 )' 38 39 // Specify remote graph by border inputs and outputs 40 bazel-bin/tensorflow/tools/graph_transforms/transform_graph \ 41 --in_graph=/tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb \ 42 --out_graph=\ 43 /tmp/tensorflow_inception_v3_stripped_optimized_quantized_fused_hexagon.pb \ 44 --inputs='Mul' \ 45 --outputs='softmax' \ 46 --transforms='\ 47 fuse_remote_graph( 48 input_types="float" \ 49 input_shapes="1,299,299,3" \ 50 border_inputs="NodeA:0,NodeB:0" \ 51 border_outputs="NodeC" \ 52 remote_fused_graph_executor_name="executor" \ 53 remote_fused_graph_node_name="node_name" \ 54 )' 55 */ 56 57 #include <unordered_set> 58 59 #include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h" 60 #include "tensorflow/tools/graph_transforms/transform_utils.h" 61 62 namespace tensorflow { 63 namespace graph_transforms { 64 65 static Status ParseArguments(const TransformFuncContext& context, 66 string* input_types_str, string* input_shapes_str, 67 string* fused_nodes_str, string* border_inputs_str, 68 string* border_outputs_str, 69 string* fused_op_types_str, bool* fuse_by_executor, 70 string* remote_fused_graph_node_name, 71 string* remote_graph_executor_name) { 72 TF_RETURN_IF_ERROR(context.GetOneStringParameter( 73 RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES, "", 74 input_types_str)); 75 TF_RETURN_IF_ERROR(context.GetOneStringParameter( 76 RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_SHAPES, "", 77 input_shapes_str)); 78 TF_RETURN_IF_ERROR(context.GetOneStringParameter( 79 RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_NODES, "", 80 fused_nodes_str)); 81 TF_RETURN_IF_ERROR(context.GetOneStringParameter( 82 RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_INPUTS, "", 83 border_inputs_str)); 84 TF_RETURN_IF_ERROR(context.GetOneStringParameter( 85 RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS, "", 86 border_outputs_str)); 87 TF_RETURN_IF_ERROR(context.GetOneStringParameter( 88 RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES, "", 89 fused_op_types_str)); 90 TF_RETURN_IF_ERROR(context.GetOneBoolParameter( 91 RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR, false, 92 fuse_by_executor)); 93 TF_RETURN_IF_ERROR(context.GetOneStringParameter( 94 RemoteFusedGraphExecuteUtils:: 95 TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME, 96 "", remote_graph_executor_name)); 97 TF_RETURN_IF_ERROR(context.GetOneStringParameter( 98 RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME, 99 "", remote_fused_graph_node_name)); 100 101 CHECK(!remote_graph_executor_name->empty()); 102 return Status::OK(); 103 } 104 105 static Status PlaceShapeType(const std::vector<string>& inputs, 106 const std::vector<string>& outputs, 107 const string& input_types_str, 108 const string& input_shapes_str, 109 GraphDef* mutable_input_graph_def) { 110 const std::vector<string> input_types_strs = 111 str_util::Split(input_types_str, ","); 112 const std::vector<string> input_shapes_strs = 113 str_util::Split(input_shapes_str, ":"); 114 CHECK_EQ(inputs.size(), input_types_strs.size()); 115 CHECK_EQ(inputs.size(), input_shapes_strs.size()); 116 std::vector<std::pair<string, Tensor>> input_tensors; 117 for (size_t i = 0; i < inputs.size(); ++i) { 118 const string& name = inputs.at(i); 119 std::vector<int64> dims; 120 CHECK(str_util::SplitAndParseAsInts(input_shapes_strs.at(i), ',', &dims)); 121 DataType data_type; 122 CHECK(DataTypeFromString(input_types_strs.at(i), &data_type)) 123 << "\"" << input_types_strs.at(i) << "\" was an invalid type"; 124 input_tensors.emplace_back( 125 std::make_pair(name, Tensor(data_type, TensorShape(dims)))); 126 } 127 TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes( 128 input_tensors, /*dry_run_inference=*/true, mutable_input_graph_def)); 129 return Status::OK(); 130 } 131 132 Status FuseRemoteGraph(const GraphDef& input_graph_def, 133 const TransformFuncContext& context, 134 GraphDef* output_graph_def) { 135 GraphDef mutable_input_graph_def = input_graph_def; 136 137 const std::vector<string>& inputs = context.input_names; 138 const std::vector<string>& outputs = context.output_names; 139 140 string input_types_str; 141 string input_shapes_str; 142 string fused_nodes_str; 143 string border_inputs_str; 144 string border_outputs_str; 145 string fused_op_types_str; 146 bool fuse_by_executor = false; 147 string remote_fused_graph_node_name; 148 string remote_graph_executor_name; 149 TF_RETURN_IF_ERROR(ParseArguments( 150 context, &input_types_str, &input_shapes_str, &fused_nodes_str, 151 &border_inputs_str, &border_outputs_str, &fused_op_types_str, 152 &fuse_by_executor, &remote_fused_graph_node_name, 153 &remote_graph_executor_name)); 154 155 if (!input_types_str.empty()) { 156 TF_RETURN_IF_ERROR(PlaceShapeType(inputs, outputs, input_types_str, 157 input_shapes_str, 158 &mutable_input_graph_def)); 159 } 160 161 const bool require_shape_type = !input_types_str.empty(); 162 if (!fused_nodes_str.empty()) { 163 const std::vector<string> fused_node_name_vector = 164 str_util::Split(fused_nodes_str, ","); 165 const std::unordered_set<string> fused_node_names( 166 fused_node_name_vector.begin(), fused_node_name_vector.end()); 167 TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::FuseRemoteGraphByNodeNames( 168 mutable_input_graph_def, inputs, outputs, remote_fused_graph_node_name, 169 fused_node_names, remote_graph_executor_name, require_shape_type, 170 output_graph_def)); 171 } else if (!border_inputs_str.empty() && !border_outputs_str.empty()) { 172 const std::vector<string> border_inputs = 173 str_util::Split(border_inputs_str, ","); 174 const std::vector<string> border_outputs = 175 str_util::Split(border_outputs_str, ","); 176 for (size_t i = 0; i < border_inputs.size(); ++i) { 177 VLOG(2) << "Border Input(" << i << "): " << border_inputs.at(i); 178 } 179 for (size_t i = 0; i < border_outputs.size(); ++i) { 180 VLOG(2) << "Border Output(" << i << "): " << border_outputs.at(i); 181 } 182 TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::FuseRemoteGraphByBorder( 183 mutable_input_graph_def, inputs, outputs, remote_fused_graph_node_name, 184 border_inputs, border_outputs, remote_graph_executor_name, 185 require_shape_type, output_graph_def)); 186 } else if (!fused_op_types_str.empty()) { 187 const std::vector<string> fused_op_type_vector = 188 str_util::Split(fused_op_types_str, ","); 189 const std::unordered_set<string> fused_op_types( 190 fused_op_type_vector.begin(), fused_op_type_vector.end()); 191 TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::FuseRemoteGraphByOpTypes( 192 mutable_input_graph_def, inputs, outputs, remote_fused_graph_node_name, 193 fused_op_types, remote_graph_executor_name, require_shape_type, 194 output_graph_def)); 195 } else if (fuse_by_executor) { 196 TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor( 197 mutable_input_graph_def, inputs, outputs, remote_graph_executor_name, 198 output_graph_def)); 199 } else { 200 LOG(FATAL) << "Fuse targets are not specified."; 201 } 202 203 return Status::OK(); 204 } 205 206 Status PlaceRemoteGraphArguments(const GraphDef& input_graph_def, 207 const TransformFuncContext& context, 208 GraphDef* output_graph_def) { 209 *output_graph_def = input_graph_def; 210 211 const std::vector<string>& inputs = context.input_names; 212 const std::vector<string>& outputs = context.output_names; 213 214 string input_types_str; 215 string input_shapes_str; 216 string fused_nodes_str; 217 string border_inputs_str; 218 string border_outputs_str; 219 string fused_op_types_str; 220 bool fuse_by_executor = false; 221 string remote_fused_graph_node_name; 222 string remote_graph_executor_name; 223 TF_RETURN_IF_ERROR(ParseArguments( 224 context, &input_types_str, &input_shapes_str, &fused_nodes_str, 225 &border_inputs_str, &border_outputs_str, &fused_op_types_str, 226 &fuse_by_executor, &remote_fused_graph_node_name, 227 &remote_graph_executor_name)); 228 229 if (!input_types_str.empty()) { 230 TF_RETURN_IF_ERROR(PlaceShapeType(inputs, outputs, input_types_str, 231 input_shapes_str, output_graph_def)); 232 } 233 234 const std::vector<string> fused_node_name_vector = 235 str_util::Split(fused_nodes_str, ","); 236 const std::unordered_set<string> fused_node_names( 237 fused_node_name_vector.begin(), fused_node_name_vector.end()); 238 const std::vector<string> border_inputs = 239 str_util::Split(border_inputs_str, ","); 240 const std::vector<string> border_outputs = 241 str_util::Split(border_outputs_str, ","); 242 const std::vector<string> fused_op_type_vector = 243 str_util::Split(fused_op_types_str, ","); 244 const std::unordered_set<string> fused_op_types(fused_op_type_vector.begin(), 245 fused_op_type_vector.end()); 246 247 TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::PlaceRemoteGraphArguments( 248 inputs, outputs, fused_node_names, border_inputs, border_outputs, 249 fused_op_types, remote_fused_graph_node_name, remote_graph_executor_name, 250 output_graph_def)); 251 252 return Status::OK(); 253 } 254 255 REGISTER_GRAPH_TRANSFORM("fuse_remote_graph", FuseRemoteGraph); 256 257 REGISTER_GRAPH_TRANSFORM("place_remote_graph_arguments", 258 PlaceRemoteGraphArguments); 259 260 } // namespace graph_transforms 261 } // namespace tensorflow 262