1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h" 17 18 #include "tensorflow/cc/framework/scope.h" 19 #include "tensorflow/cc/ops/const_op.h" 20 #include "tensorflow/core/framework/tensor_shape.pb.h" 21 #include "tensorflow/core/graph/node_builder.h" 22 #include "tensorflow/core/platform/logging.h" 23 namespace tensorflow { 24 25 // function alias 26 constexpr auto AddOutputTensorShapeTypeByTensorShapeMap = 27 &RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap; 28 29 /* static */ std::priority_queue<std::tuple<float, int, string>> 30 GraphTransferUtils::GetTopNFloatResults(const float* const data, 31 const string* const labels, 32 const int element_count) { 33 CHECK(data != nullptr); 34 CHECK(labels != nullptr); 35 std::priority_queue<std::tuple<float, int, string>> queue; 36 for (int i = 0; i < element_count; ++i) { 37 queue.emplace(data[i], i, labels[i]); 38 } 39 return queue; 40 } 41 42 /* static */ void GraphTransferUtils::DumpTopNFloatResults( 43 const float* const data, const string* const labels, 44 const int element_count, const int top_n) { 45 std::priority_queue<std::tuple<float, int, string>> queue = 46 GetTopNFloatResults(data, labels, element_count); 47 LOG(INFO) << "=== Dump ranking ==="; 48 for (int i = 0; i < top_n; ++i) { 49 const std::tuple<float, int, string>& entry = queue.top(); 50 LOG(INFO) << i << ": " << std::get<1>(entry) << ", " << std::get<2>(entry) 51 << ", " << std::get<0>(entry); 52 queue.pop(); 53 } 54 } 55 56 /* static */ RemoteFusedGraphExecuteInfo 57 GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo( 58 const GraphDef& graph_def, 59 const std::vector<std::pair<string, Tensor>>& inputs, 60 const std::vector<string>& outputs, 61 const RemoteFusedGraphExecuteUtils::TensorShapeMap& tensor_shape_map) { 62 RemoteFusedGraphExecuteInfo execute_info; 63 execute_info.set_executor_name("build_hexagon_remote_fused_graph_executor"); 64 65 // copy graph 66 *execute_info.mutable_remote_graph() = graph_def; 67 68 for (const std::pair<string, Tensor>& input : inputs) { 69 execute_info.add_graph_input_node_name(input.first); 70 RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type = 71 *execute_info.add_default_graph_input_tensor_shape(); 72 tensor_shape_type.set_dtype(input.second.dtype()); 73 TensorShapeProto& tensor_shape_proto = *tensor_shape_type.mutable_shape(); 74 for (const int64 dim : input.second.shape().dim_sizes()) { 75 tensor_shape_proto.add_dim()->set_size(dim); 76 } 77 } 78 79 for (const string& output_name : outputs) { 80 const std::pair<DataType, TensorShape>* tensor_shape_type = 81 RemoteFusedGraphExecuteUtils::GetTensorShapeType(tensor_shape_map, 82 output_name); 83 CHECK_NOTNULL(tensor_shape_type); 84 execute_info.add_graph_output_node_name(output_name); 85 RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type_proto = 86 *execute_info.add_default_graph_output_tensor_shape(); 87 tensor_shape_type_proto.set_dtype(tensor_shape_type->first); 88 TensorShapeProto& tensor_shape_proto = 89 *tensor_shape_type_proto.mutable_shape(); 90 for (const int64 dim : tensor_shape_type->second.dim_sizes()) { 91 tensor_shape_proto.add_dim()->set_size(dim); 92 } 93 } 94 95 return execute_info; 96 } 97 98 /* static */ GraphDef GraphTransferUtils::BuildFusedGraphDef( 99 const IRemoteFusedGraphOpsDefinitions& ops_definitions, 100 const string& remote_graph_execute_name, 101 const std::vector<std::pair<string, Tensor>>& inputs, 102 const std::vector<string>& outputs, GraphDef* original_def) { 103 RemoteFusedGraphExecuteUtils::TensorShapeMap tensor_shape_map; 104 Status status = RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode( 105 *original_def, inputs, true /* initialize_by_zero */, &tensor_shape_map); 106 for (NodeDef& node_def : *original_def->mutable_node()) { 107 TF_CHECK_OK( 108 AddOutputTensorShapeTypeByTensorShapeMap(tensor_shape_map, &node_def)); 109 } 110 CHECK(status.ok()); 111 112 Scope root = Scope::NewRootScope(); 113 std::vector<Output> output_list; 114 DataTypeVector input_types; 115 for (const std::pair<string, Tensor>& input_node_info : inputs) { 116 const Scope& scope = root.WithOpName(input_node_info.first); 117 Node* ret; 118 const auto unique_name = scope.GetUniqueNameForOp("Placeholder"); 119 auto builder = NodeBuilder(unique_name, "Placeholder") 120 .Attr("dtype", input_node_info.second.dtype()) 121 .Attr("shape", input_node_info.second.shape()); 122 scope.UpdateBuilder(&builder); 123 scope.UpdateStatus(builder.Finalize(scope.graph(), &ret)); 124 TF_CHECK_OK(scope.status()); 125 output_list.emplace_back(Output(ret, 0)); 126 input_types.push_back(input_node_info.second.dtype()); 127 } 128 129 const RemoteFusedGraphExecuteInfo execute_info = 130 BuildRemoteFusedGraphExecuteInfo(*original_def, inputs, outputs, 131 tensor_shape_map); 132 133 DataTypeVector output_types; 134 // Sanity-check to confirm all output data types are same. 135 for (const string& output_node_name : outputs) { 136 const std::pair<DataType, TensorShape>* tst = 137 RemoteFusedGraphExecuteUtils::GetTensorShapeType(tensor_shape_map, 138 output_node_name); 139 CHECK_NE(tst, nullptr); 140 output_types.push_back(tst->first); 141 } 142 143 const Scope& scope = root.WithOpName(remote_graph_execute_name); 144 CHECK(scope.ok()); 145 auto node_out_list = ops::AsNodeOutList(scope, InputList(output_list)); 146 Node* node; 147 const auto unique_name = scope.GetUniqueNameForOp("RemoteFusedGraphExecute"); 148 149 auto builder = NodeBuilder(unique_name, "RemoteFusedGraphExecute") 150 .Input(node_out_list) 151 .Attr("Tinputs", input_types) 152 .Attr("Toutputs", output_types) 153 .Attr("serialized_remote_fused_graph_execute_info", 154 StringPiece(execute_info.SerializeAsString())); 155 CHECK(scope.ok()); 156 scope.UpdateBuilder(&builder); 157 scope.UpdateStatus(builder.Finalize(scope.graph(), &node)); 158 CHECK(scope.ok()) << scope.status(); 159 160 GraphDef fusedGraphDef; 161 TF_CHECK_OK(root.ToGraphDef(&fusedGraphDef)); 162 return fusedGraphDef; 163 } 164 165 } // namespace tensorflow 166