/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h"

#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

// Function alias to keep call sites below concise.
constexpr auto AddOutputTensorShapeTypeByTensorShapeMap =
    &RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap;

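// Builds a max-heap of (score, index, label) tuples; callers pop entries to
// consume results in descending-score order.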
/* static */ std::priority_queue<std::tuple<float, int, string>>
GraphTransferUtils::GetTopNFloatResults(const float* const data,
                                        const string* const labels,
                                        const int element_count) {
  CHECK(data != nullptr);
  CHECK(labels != nullptr);
  std::priority_queue<std::tuple<float, int, string>> queue;
  for (int i = 0; i < element_count; ++i) {
    queue.emplace(data[i], i, labels[i]);
  }
  return queue;
}

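// Logs the |top_n| highest-scoring entries as "rank: index, label, score",
// stopping early if fewer than |top_n| results are available.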
/* static */ void GraphTransferUtils::DumpTopNFloatResults(
    const float* const data, const string* const labels,
    const int element_count, const int top_n) {
  std::priority_queue<std::tuple<float, int, string>> queue =
      GetTopNFloatResults(data, labels, element_count);
  LOG(INFO) << "=== Dump ranking ===";
  for (int i = 0; i < top_n && !queue.empty(); ++i) {
    const std::tuple<float, int, string>& entry = queue.top();
    LOG(INFO) << i << ": " << std::get<1>(entry) << ", " << std::get<2>(entry)
              << ", " << std::get<0>(entry);
    queue.pop();
  }
}

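// Assembles a RemoteFusedGraphExecuteInfo that names the Hexagon executor,
// embeds |graph_def| as the remote graph, and records the default dtype and
// shape of every graph input and output.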
/* static */ RemoteFusedGraphExecuteInfo
GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo(
    const GraphDef& graph_def,
    const std::vector<std::pair<string, Tensor>>& inputs,
    const std::vector<string>& outputs,
    const RemoteFusedGraphExecuteUtils::TensorShapeMap& tensor_shape_map) {
  RemoteFusedGraphExecuteInfo execute_info;
  execute_info.set_executor_name("build_hexagon_remote_fused_graph_executor");

  // Copy the original graph into the execute info.
  *execute_info.mutable_remote_graph() = graph_def;

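  // Record each input node's name together with its dtype and shape as the
  // default input tensor shape/type.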
  for (const std::pair<string, Tensor>& input : inputs) {
    execute_info.add_graph_input_node_name(input.first);
    RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type =
        *execute_info.add_default_graph_input_tensor_shape();
    tensor_shape_type.set_dtype(input.second.dtype());
    TensorShapeProto& tensor_shape_proto = *tensor_shape_type.mutable_shape();
    for (const int64 dim : input.second.shape().dim_sizes()) {
      tensor_shape_proto.add_dim()->set_size(dim);
    }
  }

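  // Look up each output's dtype and shape in |tensor_shape_map| and record
  // them as the default output tensor shape/type; a missing entry is fatal.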
  for (const string& output_name : outputs) {
    const std::pair<DataType, TensorShape>* tensor_shape_type =
        RemoteFusedGraphExecuteUtils::GetTensorShapeType(tensor_shape_map,
                                                         output_name);
    CHECK_NOTNULL(tensor_shape_type);
    execute_info.add_graph_output_node_name(output_name);
    RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type_proto =
        *execute_info.add_default_graph_output_tensor_shape();
    tensor_shape_type_proto.set_dtype(tensor_shape_type->first);
    TensorShapeProto& tensor_shape_proto =
        *tensor_shape_type_proto.mutable_shape();
    for (const int64 dim : tensor_shape_type->second.dim_sizes()) {
      tensor_shape_proto.add_dim()->set_size(dim);
    }
  }

  return execute_info;
}

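// Builds a GraphDef in which the whole |original_def| is collapsed into a
// single RemoteFusedGraphExecute node fed by one Placeholder per input.
// Shapes and dtypes are obtained by dry-running the original graph with
// zero-initialized inputs; |original_def| is annotated in place with the
// inferred output shape/type attributes.
//
// Illustrative usage sketch (the graph, tensor, and node names below are
// hypothetical):
//
//   GraphDef def = ...;  // graph to be fused
//   Tensor in(DT_FLOAT, TensorShape({1, 224, 224, 3}));
//   const GraphDef fused = GraphTransferUtils::BuildFusedGraphDef(
//       HexagonOpsDefinitions::getInstance(), "remote_fused_graph_execute",
//       {{"input", in}}, {"softmax"}, &def);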
/* static */ GraphDef GraphTransferUtils::BuildFusedGraphDef(
    const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    const string& remote_graph_execute_name,
    const std::vector<std::pair<string, Tensor>>& inputs,
    const std::vector<string>& outputs, GraphDef* original_def) {
  RemoteFusedGraphExecuteUtils::TensorShapeMap tensor_shape_map;
  Status status = RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode(
      *original_def, inputs, true /* initialize_by_zero */, &tensor_shape_map);
  TF_CHECK_OK(status);
  for (NodeDef& node_def : *original_def->mutable_node()) {
    TF_CHECK_OK(
        AddOutputTensorShapeTypeByTensorShapeMap(tensor_shape_map, &node_def));
  }

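  // Build a new graph: one Placeholder per original input, all of which feed
  // the single RemoteFusedGraphExecute node created below.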
  Scope root = Scope::NewRootScope();
  std::vector<Output> output_list;
  DataTypeVector input_types;
  for (const std::pair<string, Tensor>& input_node_info : inputs) {
    const Scope& scope = root.WithOpName(input_node_info.first);
    Node* ret;
    const auto unique_name = scope.GetUniqueNameForOp("Placeholder");
    auto builder = NodeBuilder(unique_name, "Placeholder")
                       .Attr("dtype", input_node_info.second.dtype())
                       .Attr("shape", input_node_info.second.shape());
    scope.UpdateBuilder(&builder);
    scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
    TF_CHECK_OK(scope.status());
    output_list.emplace_back(Output(ret, 0));
    input_types.push_back(input_node_info.second.dtype());
  }

  const RemoteFusedGraphExecuteInfo execute_info =
      BuildRemoteFusedGraphExecuteInfo(*original_def, inputs, outputs,
                                       tensor_shape_map);

  DataTypeVector output_types;
  // Collect the data type of each output from the dry-run results.
  for (const string& output_node_name : outputs) {
    const std::pair<DataType, TensorShape>* tst =
        RemoteFusedGraphExecuteUtils::GetTensorShapeType(tensor_shape_map,
                                                         output_node_name);
    CHECK_NE(tst, nullptr);
    output_types.push_back(tst->first);
  }

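  // Create the RemoteFusedGraphExecute node that consumes the placeholders
  // and carries the serialized execute info as a node attribute.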
  const Scope& scope = root.WithOpName(remote_graph_execute_name);
  CHECK(scope.ok());
  auto node_out_list = ops::AsNodeOutList(scope, InputList(output_list));
  Node* node;
  const auto unique_name = scope.GetUniqueNameForOp("RemoteFusedGraphExecute");

  auto builder = NodeBuilder(unique_name, "RemoteFusedGraphExecute")
                     .Input(node_out_list)
                     .Attr("Tinputs", input_types)
                     .Attr("Toutputs", output_types)
                     .Attr("serialized_remote_fused_graph_execute_info",
                           StringPiece(execute_info.SerializeAsString()));
  CHECK(scope.ok());
  scope.UpdateBuilder(&builder);
  scope.UpdateStatus(builder.Finalize(scope.graph(), &node));
  CHECK(scope.ok()) << scope.status();

  GraphDef fused_graph_def;
  TF_CHECK_OK(root.ToGraphDef(&fused_graph_def));
  return fused_graph_def;
}

}  // namespace tensorflow