      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/kernels/hexagon/graph_transferer.h"
     17 
     18 #include <algorithm>
     19 #include <cinttypes>
     20 
     21 #include "tensorflow/core/framework/op.h"
     22 #include "tensorflow/core/graph/algorithm.h"
     23 #include "tensorflow/core/graph/graph_constructor.h"
     24 #include "tensorflow/core/graph/node_builder.h"
     25 #include "tensorflow/core/platform/env.h"
     26 #include "tensorflow/core/platform/types.h"
     27 #include "tensorflow/core/public/session.h"
     28 #include "tensorflow/core/public/session_options.h"
     29 #include "tensorflow/core/util/tensor_slice_writer.h"
     30 
     31 namespace tensorflow {
     32 
     33 // function alias
     34 constexpr auto AddOutputTensorShapeTypeByTensorShapeMap =
     35     &RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap;
     36 
     37 constexpr bool DBG_DUMP_VERIFICATION_STRING = false;
     38 constexpr bool DBG_DUMP_PARAMS = false;
     39 
     40 const char RESHAPE_NODE_TYPE_STRING[] = "Reshape";
     41 const char SOURCE_NODE_NAME[] = "_SOURCE";
     42 const char SINK_NODE_NAME[] = "_SINK";
     43 const char INPUTS_NODE_PREFIX[] = "inputs_for_";
     44 const char OUTPUTS_NODE_PREFIX[] = "outputs_for_";
     45 const char DATA_NODE_PREFIX[] = "data_for_op_";
     46 const char CONST_SHAPE_PREFIX[] = "const_shape_";
     47 const char CONST_VAL_PREFIX[] = "const_val_";
     48 const char CONST_TENSOR_PREFIX[] = "const_tensor_";
     49 const char PADDING_ATTR_NAME[] = "padding";
     50 const char STRIDES_ATTR_NAME[] = "strides";
     51 const char KEEP_DIMS_ATTR_NAME[] = "keep_dims";
     52 const char KSIZE_ATTR_NAME[] = "ksize";
     53 const char NULL_OUTPUT_NAME[] = "NULL";
     54 const char AGGREGATED_INPUT_NODE_NAME[] = "graph_transfer_aggregated_input";
     55 const int PADDING_NA_ID = 0;  // VALID = 1, SAME = 2
     56 
      57 // This is a temporary workaround to support Android builds,
      58 // where std::to_string is not available even with the c++11 option.
     59 template <typename T>
     60 static string ToString(T val) {
     61   std::stringstream stream;
     62   stream << val;
     63   return stream.str();
     64 }
     65 
     66 static Node* FindMutableNodeByName(const string& name, Graph* graph) {
     67   const TensorId tid = ParseTensorName(name);
     68   for (Node* node : graph->nodes()) {
     69     if (node != nullptr && node->name() == tid.first) {
     70       return node;
     71     }
     72   }
     73   return nullptr;
     74 }
     75 
     76 /**
     77  * graph loading functions
     78  * - LoadGraphFromProto
      79  * - LoadGraphFromProtoFile
      80  * These functions read a graph definition and store the parameters
      81  * of each node needed to transfer the graph to the SoC.
     82  */
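         // A minimal usage sketch (hypothetical caller code; the tensor names, the file
         // path and the HexagonOpsDefinitions::getInstance() accessor are assumptions
         // for illustration, not definitions from this file):
         //
         //   GraphTransferer gt;
         //   const std::vector<std::pair<string, Tensor>> inputs{
         //       {"input", Tensor(DT_FLOAT, TensorShape({1, 224, 224, 3}))}};
         //   const std::vector<string> outputs{"softmax"};
         //   TF_CHECK_OK(gt.LoadGraphFromProtoFile(
         //       HexagonOpsDefinitions::getInstance(), "/path/to/graph.pb", inputs,
         //       outputs, false /* is_text_proto */,
         //       true /* shape_inference_for_unknown_shape */,
         //       false /* dry_run_for_unknown_shape */));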
     83 Status GraphTransferer::LoadGraphFromProto(
     84     const IRemoteFusedGraphOpsDefinitions& ops_definitions,
     85     const GraphDef& graph_def,
     86     const std::vector<std::pair<string, Tensor>>& input_node_info_list,
     87     const std::vector<string>& output_node_names,
     88     const bool shape_inference_for_unknown_shape) {
     89   Graph graph(OpRegistry::Global());
     90   ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
     91   Status status = ImportGraphDef({}, graph_def, &graph, &shape_refiner);
     92   if (!status.ok()) {
     93     return status;
     94   }
     95 
     96   if (shape_inference_for_unknown_shape) {
     97     status = RemoteFusedGraphExecuteUtils::PropagateShapeInference(
     98         graph_def, input_node_info_list, &graph, &shape_refiner);
     99     if (!status.ok()) {
    100       return status;
    101     }
    102   }
    103 
    104   TF_RETURN_IF_ERROR(TransformGraphToAddAggregatedInputNode(
    105       input_node_info_list, &graph, &shape_refiner));
    106 
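           // First pass: cache every node so that each one is assigned a stable
           // integer id (see CacheNode).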
    107   std::unordered_multimap<string, const Node*> op_name_to_node_multimap(
    108       graph.num_nodes());
    109   for (const Node* const node : graph.nodes()) {
    110     if (node == nullptr) {
    111       continue;
    112     }
    113     CacheNode(*node);
    114   }
    115 
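           // Second pass: record each node's consumers in a name -> consumer multimap
           // and log the dependencies.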
    116   for (const Node* const node : graph.nodes()) {
    117     if (node == nullptr) {
    118       continue;
    119     }
    120     VLOG(1) << "<Node> " << node->name();
    121     for (const Node* const input_node : node->in_nodes()) {
    122       const string& name = input_node->name();
    123       op_name_to_node_multimap.emplace(name, node);
    124       VLOG(1) << "Add dependency: " << name << " -> " << node->name();
    125     }
    126   }
    127 
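           // Third pass: register every node. All of a node's inputs are expected to be
           // cached already, which RegisterNodeIfAllInputsAreCached() CHECKs.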
    128   for (const Node* const node : graph.nodes()) {
    129     if (node == nullptr) {
    130       continue;
    131     }
    132     status = RegisterNodeIfAllInputsAreCached(
    133         ops_definitions, shape_refiner, *node, false, input_node_info_list,
    134         output_node_names);
    135     if (!status.ok()) {
    136       LOG(ERROR) << "Failed to transfer graph " << status;
    137       return status;
    138     }
    139   }
    140 
    141   SortParams(output_node_names);
    142 
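           // Record the name, dtype and shape of each graph input in the transfer info.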
    143   for (const std::pair<string, Tensor>& input_node_info :
    144        input_node_info_list) {
    145     GraphTransferInfo::GraphInputNodeInfo& graph_input_node_info =
    146         *graph_transfer_info_.add_graph_input_node_info();
    147     graph_input_node_info.set_name(input_node_info.first);
    148     graph_input_node_info.set_dtype(input_node_info.second.dtype());
    149     for (const int64 dim : ToTensorShapeArray(input_node_info.second.shape())) {
    150       graph_input_node_info.add_shape(dim);
    151     }
    152   }
    153 
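           // Record each graph output; its dtype and shape are filled in only when the
           // producing node carries output shape/type attributes.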
    154   for (const string& output_node_name : output_node_names) {
    155     const TensorId tid = ParseTensorName(output_node_name);
    156     const string node_name = tid.first.ToString();
    157     const int port = tid.second;
    158     const int node_id = node_name_to_id_cache_map_.at(node_name);
    159     const Node* node = node_name_cache_list_.at(node_id);
    160     CHECK_NOTNULL(node);
    161 
    162     GraphTransferInfo::GraphOutputNodeInfo& graph_output_node_info =
    163         *graph_transfer_info_.add_graph_output_node_info();
    164     graph_output_node_info.set_name(strings::StrCat(node_name, ":", port));
    165 
    166     // Get output tensor shape type
    167     std::vector<DataType> data_types;
    168     std::vector<TensorShape> shapes;
    169     status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
    170         node->attrs(), &data_types, &shapes);
    171     if (status.ok()) {
    172       CHECK(data_types.size() > port);
    173       graph_output_node_info.set_dtype(data_types.at(port));
    174       for (const int64 dim : ToTensorShapeArray(shapes.at(port))) {
    175         graph_output_node_info.add_shape(dim);
    176       }
    177     }
    178   }
    179 
    180   ClearCache();
    181   if (DBG_DUMP_PARAMS) {
    182     DumpNodeTransferParams();
    183   }
    184   if (DBG_DUMP_VERIFICATION_STRING) {
    185     DumpVerificationStringOfNodeTransferParams();
    186   }
    187   return Status();
    188 }
    189 
    190 Status GraphTransferer::LoadGraphFromProtoFile(
    191     const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    192     const string& graph_def_path,
    193     const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    194     const std::vector<string>& output_node_names, const bool is_text_proto,
    195     const bool shape_inference_for_unknown_shape,
    196     const bool dry_run_for_unknown_shape) {
    197   GraphDef graph_def;
    198   string output;
    199   Status status;
    200   VLOG(1) << "Parse file " << graph_def_path;
    201   if (is_text_proto) {
    202     status = ReadFileToString(Env::Default(), graph_def_path, &output);
    203     if (!protobuf::TextFormat::ParseFromString(output, &graph_def)) {
    204       return errors::InvalidArgument("Cannot parse proto string.");
    205     }
    206   } else {
    207     status = ReadBinaryProto(Env::Default(), graph_def_path, &graph_def);
    208   }
    209   if (!status.ok()) {
    210     VLOG(1) << "Failed to load graph " << status;
    211     return status;
    212   }
    213   if (dry_run_for_unknown_shape) {
    214     VLOG(1) << "Dry run graph to obtain shape of nodes";
    215     RemoteFusedGraphExecuteUtils::TensorShapeMap tensor_shape_map;
    216     status = RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode(
    217         graph_def, input_node_info_list, true, &tensor_shape_map);
    218     if (!status.ok()) {
    219       return status;
    220     }
    221     for (NodeDef& node_def : *graph_def.mutable_node()) {
    222       TF_CHECK_OK(AddOutputTensorShapeTypeByTensorShapeMap(tensor_shape_map,
    223                                                            &node_def));
    224     }
    225   }
    226   VLOG(1) << "Load graph with output tensors";
    227   return LoadGraphFromProto(ops_definitions, graph_def, input_node_info_list,
    228                             output_node_names,
    229                             shape_inference_for_unknown_shape);
    230 }
    231 
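         // Sorts the registered node params so that every node appears after the nodes
         // it depends on: a per-node dependency set is built from the recorded inputs,
         // transitively closed from the output nodes (FillDependencyRec), and then used
         // by TransferParamsComparator to order node_info.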
    232 void GraphTransferer::SortParams(const std::vector<string>& output_node_names) {
    233   // TODO(satok): optimize complexity
    234   std::unordered_map<int, GraphTransferInfo::NodeInputInfo*> input_map;
    235   for (GraphTransferInfo::NodeInputInfo& input :
    236        *graph_transfer_info_.mutable_node_input_info()) {
    237     input_map.emplace(input.node_id(), &input);
    238   }
    239 
    240   // Setup dependency map placeholder
    241   std::vector<int> output_node_ids;
    242   std::unordered_map<int, std::unordered_set<int>> dependency_map;
    243   for (const GraphTransferInfo::NodeInfo& params :
    244        graph_transfer_info_.node_info()) {
    245     const int node_id = params.node_id();
    246     for (const string& output_node_name : output_node_names) {
    247       if (params.name() == output_node_name) {
    248         output_node_ids.emplace_back(node_id);
    249       }
    250     }
    251 
    252     dependency_map.emplace(std::piecewise_construct, std::make_tuple(node_id),
    253                            std::make_tuple());
    254     if (params.input_count() == 0) {
    255       continue;
    256     }
    257     CHECK_EQ(input_map.count(node_id), 1);
    258     for (const GraphTransferInfo::NodeInput& node_input :
    259          input_map.at(node_id)->node_input()) {
    260       dependency_map.at(node_id).emplace(node_input.node_id());
    261     }
    262   }
    263 
    264   // Create dependency map traversed from output nodes
    265   std::unordered_set<int> completed;
    266   for (int output_node_id : output_node_ids) {
    267     FillDependencyRec(output_node_id, dependency_map, completed);
    268   }
    269 
    270   std::sort(graph_transfer_info_.mutable_node_info()->begin(),
    271             graph_transfer_info_.mutable_node_info()->end(),
    272             TransferParamsComparator(dependency_map));
    273 }
    274 
    275 void GraphTransferer::EnableStrictCheckMode(const bool enable) {
    276   strict_check_mode_ = enable;
    277 }
    278 
    279 void GraphTransferer::SetSerializedGraphTransferInfo(
    280     const string& serialized_proto) {
    281   graph_transfer_info_.ParseFromString(serialized_proto);
    282 }
    283 
    284 const GraphTransferInfo& GraphTransferer::GetGraphTransferInfo() const {
    285   return graph_transfer_info_;
    286 }
    287 
    288 GraphTransferInfo& GraphTransferer::GetMutableGraphTransferInfo() {
    289   return graph_transfer_info_;
    290 }
    291 
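         // Assigns the node the next sequential id (its index in node_name_cache_list_)
         // unless its name has already been cached.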
    292 void GraphTransferer::CacheNode(const Node& node) {
    293   if (node_name_to_id_cache_map_.count(node.name()) > 0) {
    294     return;
    295   }
    296   node_name_cache_list_.emplace_back(&node);
    297   const int node_id = node_name_cache_list_.size() - 1;
    298   bool emplace_succeeded = false;
    299   std::tie(std::ignore, emplace_succeeded) =
    300       node_name_to_id_cache_map_.emplace(node.name(), node_id);
    301   CHECK(emplace_succeeded);
    302 }
    303 
    304 bool GraphTransferer::AreAllInputsCached(const Node& node) const {
    305   for (const Node* const input_node : node.in_nodes()) {
    306     if (node_name_to_id_cache_map_.count(input_node->name()) <= 0) {
    307       VLOG(1) << "input_node " << input_node->name() << " of " << node.name()
    308               << " is not cached yet.";
    309       return false;
    310     }
    311   }
    312   return true;
    313 }
    314 
    315 Status GraphTransferer::TransformGraphToAddAggregatedInputNode(
    316     const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    317     Graph* graph, ShapeRefiner* shape_refiner) {
    318   // Transform a remote fused graph to add an aggregated input node which takes
    319   // all inputs of the remote graph.
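           //
           // For example, with feeds {A, B} the transformed graph conceptually becomes
           // (names are illustrative only):
           //
           //   graph_transfer_aggregated_input -:0-> Identity("A") -> consumers of A
           //                                    -:1-> Identity("B") -> consumers of B
           //
           // and the original A and B nodes are removed at the end of this function.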
    320   DataTypeVector input_data_types;
    321   std::vector<DataType> data_types;
    322   std::vector<TensorShape> shapes;
    323   std::vector<string> input_nodes;
    324   for (int i = 0; i < input_node_info_list.size(); ++i) {
    325     Node* node = FindMutableNodeByName(input_node_info_list.at(i).first, graph);
    326     CHECK_NOTNULL(node);
    327     input_nodes.emplace_back(node->name());
    328     input_data_types.emplace_back(input_node_info_list.at(i).second.dtype());
    329     data_types.emplace_back(input_node_info_list.at(i).second.dtype());
    330     shapes.emplace_back(input_node_info_list.at(i).second.shape());
    331   }
    332 
    333   NodeDef input_node_def;
    334   auto builder =
    335       NodeBuilder(AGGREGATED_INPUT_NODE_NAME, "RemoteFusedGraphExecute")
    336           .Input(std::vector<NodeBuilder::NodeOut>{})
    337           .Attr("Tinputs", DataTypeVector{})
    338           .Attr("Toutputs", input_data_types)
    339           .Attr("serialized_remote_fused_graph_execute_info", "")
    340           .Attr(RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES,
    341                 data_types)
    342           .Attr(RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES, shapes);
    343 
    344   Node* input_node;
    345   TF_RETURN_IF_ERROR(builder.Finalize(graph, &input_node));
    346   CHECK_NOTNULL(input_node);
    347 
    348   bool refined;
    349   TF_RETURN_IF_ERROR(
    350       shape_refiner->UpdateNode(input_node, false /* relax */, &refined));
    351 
    352   shape_inference::InferenceContext* context =
    353       shape_refiner->GetContext(input_node);
    354   for (int i = 0; i < input_node_info_list.size(); ++i) {
    355     shape_inference::ShapeHandle handle;
    356     TF_RETURN_IF_ERROR(context->MakeShapeFromTensorShape(
    357         input_node_info_list.at(i).second.shape(), &handle));
    358     TF_RETURN_IF_ERROR(shape_refiner->SetShape(input_node, i, handle));
    359   }
    360 
    361   // Cache the aggregate input node first as it's consumed first.
    362   CacheNode(*input_node);
    363 
    364   std::vector<Node*> original_input_nodes(input_nodes.size());
    365 
    366   for (int i = 0; i < input_nodes.size(); ++i) {
    367     const string& node_name = input_nodes.at(i);
    368     Node* original_input_node = FindMutableNodeByName(node_name, graph);
    369     CHECK_NOTNULL(original_input_node);
    370     CHECK_EQ(1, original_input_node->num_outputs());  // replaced by identity.
    371     Node* created_node;
    372     TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::BuildIdentityOpNode(
    373         node_name, AGGREGATED_INPUT_NODE_NAME, i, data_types.at(i), graph,
    374         &created_node));
    375     CHECK_NOTNULL(created_node);
    376     std::vector<DataType> data_types;
    377     std::vector<TensorShape> shapes;
    378     Status status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
    379         original_input_node->attrs(), &data_types, &shapes);
    380     if (status.ok()) {
    381       created_node->AddAttr(
    382           RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES, data_types);
    383       created_node->AddAttr(RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES,
    384                             shapes);
    385     }
    386     for (const Edge* out_edge : original_input_node->out_edges()) {
    387       Node* dst = out_edge->dst();
    388       int dst_port = out_edge->dst_input();
     389       // The now-unused edge is removed when the original node is removed below.
    390       graph->AddEdge(created_node, 0, dst, dst_port);
    391     }
    392     original_input_nodes[i] = original_input_node;
    393 
    394     TF_RETURN_IF_ERROR(
    395         shape_refiner->UpdateNode(created_node, false /* relax */, &refined));
    396 
    397     shape_inference::InferenceContext* context =
    398         shape_refiner->GetContext(created_node);
    399     CHECK_NOTNULL(context);
    400 
    401     // Cache replaced input node next to the aggregated input node.
    402     CacheNode(*created_node);
    403   }
    404 
     405   // Remove the original input nodes only after the new input nodes have been
     406   // added, so the Graph does not reuse the same Node pointer.
    407   for (Node* original_input_node : original_input_nodes) {
    408     graph->RemoveNode(original_input_node);
    409   }
    410 
    411   return Status::OK();
    412 }
    413 
    414 Status GraphTransferer::RegisterNode(
    415     const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    416     const ShapeRefiner& shape_refiner, const Node& node,
    417     const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    418     const std::vector<string>& output_node_names) {
    419   VLOG(1) << "Register node: " << node.name() << ", " << std::hex
    420           << node_name_to_id_cache_map_.at(node.name());
    421   if (node.name() == SOURCE_NODE_NAME || node.name() == SINK_NODE_NAME) {
    422     // Just ignore sink and source
    423     return Status::OK();
    424   } else if (node.name() == AGGREGATED_INPUT_NODE_NAME) {
    425     RegisterInputNode(ops_definitions, shape_refiner, node);
    426     return Status::OK();
    427   } else if (node.IsConstant()) {
    428     RegisterConstantNode(shape_refiner, node);
    429   } else if (IsPadNode(node)) {
    430     RegisterPadNode(ops_definitions, shape_refiner, node);
    431   } else if (HasPaddingAndStrides(node)) {
    432     RegisterNodeWithPaddingAndStrides(ops_definitions, shape_refiner, node);
    433   } else if (NeedsToAddRank(node)) {
    434     RegisterNodeWithRank(ops_definitions, shape_refiner, node);
    435   } else if (IsNodeFlattenReshape(node, shape_refiner)) {
    436     RegisterFlattenNode(ops_definitions, shape_refiner, node);
    437   } else if (ops_definitions.GetOpIdFor(node.type_string(), {}) !=
    438              IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID) {
    439     // TODO(satok): Set correct data type if it's given.
    440     RegisterGenericNode(ops_definitions, shape_refiner, node);
    441   } else {
    442     return errors::InvalidArgument(node.type_string() +
    443                                    " has not been implemented yet.");
    444   }
    445 
    446   return Status::OK();
    447 }
    448 
    449 void GraphTransferer::RegisterConstantNode(const ShapeRefiner& shape_refiner,
    450                                            const Node& node) {
    451   VLOG(1) << "Register constant node: " << node.name();
    452   CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
    453   const int id = node_name_to_id_cache_map_[node.name()];
    454   const int output_node_size = node.num_outputs();
    455   CHECK_EQ(output_node_size, 1);
    456   // TODO(satok): support multiple outputs?
    457   const int output_index = 0;
    458   const DataType dt = node.output_type(output_index);
    459   const size_t max_bytes_per_data = DataTypeSize(dt);
     460   CHECK_GT(max_bytes_per_data, 0)
     461       << "dt = " << dt << ", " << DataTypeString(dt) << ", "
     462       << max_bytes_per_data << ", "
     463       << static_cast<int>(DataTypeSize(dt));
    464   shape_inference::InferenceContext* context = shape_refiner.GetContext(&node);
    465   shape_inference::ShapeHandle shape_handle = context->output(output_index);
    466   const shape_inference::DimensionHandle num_elements_dim =
    467       context->NumElements(shape_handle);
    468   std::array<int64, SHAPE_ARRAY_SIZE> shape_array;
    469   int data_size;
    470   // Shape of constant node must be known
    471   CHECK(context->ValueKnown(num_elements_dim));
    472   const int64 num_output_elements = context->Value(num_elements_dim);
    473   data_size = max_bytes_per_data * num_output_elements;
    474   shape_array = BuildShapeArray(shape_handle, context);
    475 
    476   GraphTransferInfo::ConstNodeInfo& const_node_info =
    477       *graph_transfer_info_.add_const_node_info();
    478   const_node_info.set_name(node.name());
    479   const_node_info.set_node_id(id);
    480   // TODO(satok): Make this generic. Never assume rank is 4.
    481   CHECK_EQ(4, SHAPE_ARRAY_SIZE);
    482   const_node_info.add_shape(shape_array[0]);
    483   const_node_info.add_shape(shape_array[1]);
    484   const_node_info.add_shape(shape_array[2]);
    485   const_node_info.add_shape(shape_array[3]);
    486   const TensorProto* proto = nullptr;
    487   TF_CHECK_OK(GetNodeAttr(node.attrs(), "value", &proto));
    488   Tensor const_tensor;
    489   TF_CHECK_OK(MakeTensorFromProto(*proto, &const_tensor));
    490 
    491   const_node_info.set_dtype(const_tensor.dtype());
    492   if (data_size > 0) {
    493     const_node_info.set_data(const_tensor.tensor_data().data(), data_size);
    494   }
    495 }
    496 
    497 int GraphTransferer::RegisterConstantShape(const std::vector<int>& shape) {
    498   VLOG(1) << "Cache constant shape.";
    499   // TODO(satok): Handle non-4dim strides
    500   CHECK_EQ(shape.size(), 4);
    501   const string shape_name = CONST_SHAPE_PREFIX + ToString(shape.at(0)) + 'x' +
    502                             ToString(shape.at(1)) + 'x' +
    503                             ToString(shape.at(2)) + 'x' + ToString(shape.at(3));
    504   if (node_name_to_id_cache_map_.count(shape_name) <= 0) {
    505     node_name_cache_list_.emplace_back(nullptr);
    506     const int id = node_name_cache_list_.size() - 1;
    507     node_name_to_id_cache_map_.emplace(shape_name, id);
    508     GraphTransferInfo::ConstNodeInfo& const_node_info =
    509         *graph_transfer_info_.add_const_node_info();
    510     const_node_info.set_name(shape_name);
    511     const_node_info.set_node_id(id);
     512     // TODO(satok): Make this generic. Never assume rank is 4.
    513     const_node_info.add_shape(static_cast<int64>(shape[0]));
    514     const_node_info.add_shape(static_cast<int64>(shape[1]));
    515     const_node_info.add_shape(static_cast<int64>(shape[2]));
    516     const_node_info.add_shape(static_cast<int64>(shape[3]));
    517   }
    518   return node_name_to_id_cache_map_[shape_name];
    519 }
    520 
    521 int GraphTransferer::RegisterConstTensor(const Tensor& tensor,
    522                                          const string& suffix) {
    523   VLOG(1) << "Cache const tensor.";
    524   const int dims = tensor.shape().dims();
    525   CHECK(dims <= 4);
    526   const string node_name = strings::StrCat(CONST_TENSOR_PREFIX, "_", suffix);
    527   if (node_name_to_id_cache_map_.count(node_name) <= 0) {
    528     node_name_cache_list_.emplace_back(nullptr);
    529     const int id = node_name_cache_list_.size() - 1;
    530     node_name_to_id_cache_map_.emplace(node_name, id);
    531     GraphTransferInfo::ConstNodeInfo& const_node_info =
    532         *graph_transfer_info_.add_const_node_info();
    533     const_node_info.set_name(node_name);
    534     const_node_info.set_node_id(id);
    535     CHECK_EQ(4, SHAPE_ARRAY_SIZE);
    536     for (int i = 0; i < SHAPE_ARRAY_SIZE; ++i) {
    537       if (i < SHAPE_ARRAY_SIZE - dims) {
    538         const_node_info.add_shape(1);
    539       } else {
    540         const_node_info.add_shape(
    541             tensor.shape().dim_size(i - (SHAPE_ARRAY_SIZE - dims)));
    542       }
    543     }
    544     const_node_info.set_dtype(tensor.dtype());
    545     const_node_info.set_data(tensor.tensor_data().data(),
    546                              tensor.tensor_data().size());
    547   }
    548   return node_name_to_id_cache_map_[node_name];
    549 }
    550 
    551 int GraphTransferer::RegisterConstScalar(const DataType dt, const int val,
    552                                          const int dst_id,
    553                                          const int dst_input_count) {
    554   VLOG(1) << "Cache const.";
    555   const string val_name =
    556       CONST_VAL_PREFIX + ToString(dst_id) + '_' + ToString(dst_input_count);
    557   if (node_name_to_id_cache_map_.count(val_name) <= 0) {
    558     node_name_cache_list_.emplace_back(nullptr);
    559     const int id = node_name_cache_list_.size() - 1;
    560     node_name_to_id_cache_map_.emplace(val_name, id);
    561     GraphTransferInfo::ConstNodeInfo& const_node_info =
    562         *graph_transfer_info_.add_const_node_info();
    563     const_node_info.set_name(val_name);
    564     const_node_info.set_node_id(id);
    565     // TODO(satok): Do not assume rank is 4 here.
    566     const_node_info.add_shape(static_cast<int64>(1));
    567     const_node_info.add_shape(static_cast<int64>(1));
    568     const_node_info.add_shape(static_cast<int64>(1));
    569     const_node_info.add_shape(static_cast<int64>(1));
    570     const_node_info.set_data(&val, DataTypeSize(dt));
    571   }
    572   return node_name_to_id_cache_map_[val_name];
    573 }
    574 
    575 bool GraphTransferer::HasPaddingAndStrides(const Node& node) {
    576   auto attrs = node.attrs();
    577   return attrs.Find(PADDING_ATTR_NAME) != nullptr &&
    578          attrs.Find(STRIDES_ATTR_NAME) != nullptr;
    579 }
    580 
    581 bool GraphTransferer::NeedsToAddRank(const Node& node) {
    582   const StringPiece op_type(node.type_string());
    583   if (op_type == "Transpose" || op_type == "ExpandDims") {
    584     return true;
    585   }
    586   return false;
    587 }
    588 
    589 bool GraphTransferer::IsPadNode(const Node& node) {
    590   const StringPiece op_type(node.type_string());
    591   if (op_type == "Pad") {
    592     return true;
    593   }
    594   return false;
    595 }
    596 
    597 bool GraphTransferer::IsNodeFlattenReshape(const Node& node,
    598                                            const ShapeRefiner& shape_refiner) {
    599   // Check if node is reshape op
    600   if (node.type_string() != RESHAPE_NODE_TYPE_STRING) {
    601     return false;
    602   }
    603 
    604   shape_inference::InferenceContext* context = shape_refiner.GetContext(&node);
    605   // Check if output count is valid
    606   if (context->num_outputs() != 1) {
    607     return false;
    608   }
    609 
    610   shape_inference::ShapeHandle shape_handle = context->output(0);
    611   std::array<int64, SHAPE_ARRAY_SIZE> shape_array;
    612   const shape_inference::DimensionHandle dim_handle =
    613       context->NumElements(shape_handle);
    614 
    615   // Obtain shape of output of node
    616   if (context->ValueKnown(dim_handle)) {
    617     shape_array = BuildShapeArray(shape_handle, context);
    618   } else {
    619     std::vector<TensorShape> shapes;
    620     TF_CHECK_OK(RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
    621         node.attrs(), nullptr, &shapes));
    622 
    623     // Number of outputs should be 1 for reshape node.
    624     CHECK_EQ(1, shapes.size());
    625     shape_array = ToTensorShapeArray(shapes.at(0));
    626   }
    627 
     628   // Check whether the reshape just flattens the tensor (leading dims all 1).
    629   if (shape_array[0] == 1 && shape_array[1] == 1 && shape_array[2] == 1) {
    630     return true;
    631   } else {
    632     return false;
    633   }
    634 }
    635 
    636 void GraphTransferer::RegisterNodeWithPaddingAndStrides(
    637     const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    638     const ShapeRefiner& shape_refiner, const Node& node) {
    639   CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
    640   const int id = node_name_to_id_cache_map_[node.name()];
    641   shape_inference::InferenceContext* context = shape_refiner.GetContext(&node);
    642   CHECK(node.attrs().Find(PADDING_ATTR_NAME));
    643   // TODO(satok): Use context->GetAttr(...) instead?
    644   Padding padding;
    645   TF_CHECK_OK(context->GetAttr(PADDING_ATTR_NAME, &padding));
    646   CHECK(node.attrs().Find(STRIDES_ATTR_NAME));
    647   std::vector<int32> strides;
    648   TF_CHECK_OK(context->GetAttr(STRIDES_ATTR_NAME, &strides));
    649   const int stride_id = RegisterConstantShape(strides);
    650   std::vector<int> extra_inputs{stride_id};
    651   if (node.attrs().Find(KSIZE_ATTR_NAME)) {
    652     std::vector<int32> kernel_sizes;
    653     TF_CHECK_OK(context->GetAttr(KSIZE_ATTR_NAME, &kernel_sizes));
    654     const int ksize_id = RegisterConstantShape(kernel_sizes);
    655     extra_inputs.insert(extra_inputs.begin(), ksize_id);
    656   }
    657   // TODO(satok): Set correct data type if it's given.
    658   const int op_type_id = ops_definitions.GetOpIdFor(node.type_string(), {});
    659   CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount())
    660       << "Op " << node.type_string() << " not found in map(id = " << op_type_id
    661       << ")";
     662   // Safety check of padding id: only VALID and SAME are expected here.
     663   CHECK(padding == Padding::VALID || padding == Padding::SAME);
    664   AppendNodeParamsWithIoParams(
    665       shape_refiner, node, node.name(), id, node.type_string(), op_type_id,
    666       static_cast<int>(padding), node.num_inputs(), extra_inputs,
    667       node.num_outputs(), true /* append_input */, true /* append_output */);
    668 }
    669 
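         // Registers ops that need the rank of their first input as an extra argument
         // (see NeedsToAddRank: Transpose and ExpandDims). The rank is appended as an
         // int32 const scalar input, and keep_dims, when present, is encoded in the
         // padding id field (SAME = keep, VALID = drop).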
    670 void GraphTransferer::RegisterNodeWithRank(
    671     const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    672     const ShapeRefiner& shape_refiner, const Node& node) {
    673   CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
    674   const int id = node_name_to_id_cache_map_[node.name()];
    675   shape_inference::InferenceContext* context = shape_refiner.GetContext(&node);
    676   const Node* input0_node;
    677   TF_CHECK_OK(node.input_node(0, &input0_node));
    678   CHECK_NOTNULL(input0_node);
    679   std::vector<TensorShape> shapes;
    680   Status status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
    681       input0_node->attrs(), nullptr, &shapes);
    682   CHECK_EQ(1, shapes.size()) << "Output size should be 1.";
    683   const int const_val_id =
    684       RegisterConstScalar(DT_INT32, shapes.at(0).dims(), id, node.num_inputs());
    685   std::vector<int> extra_inputs{const_val_id};
    686   // TODO(satok): Set correct data type if it's given.
    687   const int op_type_id = ops_definitions.GetOpIdFor(node.type_string(), {});
    688   CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount())
    689       << "Op " << node.type_string() << " not found in map(id = " << op_type_id
    690       << ")";
    691   bool keep_dims = false;
    692   int padding_id = PADDING_NA_ID;
    693   if (context->GetAttr(KEEP_DIMS_ATTR_NAME, &keep_dims).ok()) {
    694     padding_id = keep_dims ? Padding::SAME : Padding::VALID;
    695   }
    696 
    697   AppendNodeParamsWithIoParams(
    698       shape_refiner, node, node.name(), id, node.type_string(), op_type_id,
    699       padding_id, node.num_inputs(), extra_inputs, node.num_outputs(),
    700       true /* append_input */, true /* append_output */);
    701 }
    702 
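         // Registers a Pad node. The second input (the paddings) must be a rank-2
         // constant of shape [N, 2]; when N < 4 the int32 paddings are widened to
         // [4, 2] by prepending zero rows so the SOC side can assume rank 4.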
    703 void GraphTransferer::RegisterPadNode(
    704     const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    705     const ShapeRefiner& shape_refiner, const Node& node) {
    706   static constexpr int PAD_WIDTH = 4;
    707   static constexpr int PAD_HEIGHT = 2;
     708   VLOG(1) << "Register pad node: " << node.name();
    709   CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
    710   const int id = node_name_to_id_cache_map_[node.name()];
    711 
    712   // TODO(satok): Set correct data type if it's given.
    713   const int op_type_id = ops_definitions.GetOpIdFor(node.type_string(), {});
    714   CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount());
    715 
    716   CHECK_EQ(2, node.num_inputs());
    717 
    718   GraphTransferInfo::NodeInputInfo& node_input_info =
    719       *graph_transfer_info_.add_node_input_info();
    720   node_input_info.set_node_id(id);
    721 
    722   AddNodeInputByInputIndex(node, 0, &node_input_info);
    723 
    724   const Edge* edge = nullptr;
    725   TF_CHECK_OK(node.input_edge(1, &edge));
    726   const Node* input_node = edge->src();
    727   CHECK_NOTNULL(input_node);
    728   CHECK(input_node->IsConstant());
    729 
    730   const TensorProto* tensor_proto = nullptr;
    731   TF_CHECK_OK(GetNodeAttr(input_node->attrs(), "value", &tensor_proto));
    732   CHECK_NOTNULL(tensor_proto);
    733   Tensor const_tensor;
    734   TF_CHECK_OK(MakeTensorFromProto(*tensor_proto, &const_tensor));
    735   CHECK_EQ(2, const_tensor.shape().dims());
    736   CHECK_EQ(PAD_HEIGHT, const_tensor.shape().dim_size(1));
    737   if (const_tensor.shape().dim_size(0) == PAD_WIDTH) {
    738     AddNodeInputByInputIndex(node, 1, &node_input_info);
    739   } else if (const_tensor.shape().dim_size(0) < PAD_WIDTH) {
    740     const int width = const_tensor.shape().dim_size(0);
    741     const TensorProto* proto = nullptr;
    742     TF_CHECK_OK(GetNodeAttr(input_node->attrs(), "value", &proto));
    743     Tensor const_tensor;
    744     TF_CHECK_OK(MakeTensorFromProto(*proto, &const_tensor));
    745     CHECK_EQ(DT_INT32, const_tensor.dtype());
     746     // Widen the paddings input so it covers rank 4 (shape [4, 2]).
     747     // TODO(satok): Never assume rank is 4.
    748     Tensor new_const_tensor(const_tensor.dtype(), TensorShape{4, 2});
    749     for (int i = 0; i < PAD_HEIGHT; ++i) {
    750       for (int j = 0; j < PAD_WIDTH; ++j) {
    751         if (j < PAD_WIDTH - width) {
    752           new_const_tensor.matrix<int32>()(j, i) = 0;
    753         } else {
    754           new_const_tensor.matrix<int32>()(j, i) =
    755               const_tensor.matrix<int32>()(j - (PAD_WIDTH - width), i);
    756         }
    757       }
    758     }
    759 
    760     const int id = RegisterConstTensor(
    761         new_const_tensor,
    762         strings::StrCat(input_node->name(), "_", node.name(), "_1"));
    763 
    764     GraphTransferInfo::NodeInput& node_input =
    765         *node_input_info.add_node_input();
    766     node_input.set_node_id(id);
    767     node_input.set_output_port(0);
    768   } else {
    769     LOG(FATAL);
    770   }
    771 
    772   AppendNodeParamsWithIoParams(
    773       shape_refiner, node, node.name(), id, node.type_string(), op_type_id,
    774       PADDING_NA_ID, node.num_inputs(), {}, node.num_outputs(),
    775       false /* append_input */, true /* append_output */);
    776 }
    777 
    778 void GraphTransferer::RegisterInputNode(
    779     const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    780     const ShapeRefiner& shape_refiner, const Node& node) {
    781   const string op_type = node.type_string();
    782   VLOG(1) << "Register input node: " << node.name() << ", " << op_type;
    783   CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
    784   const int id = node_name_to_id_cache_map_[node.name()];
    785   // TODO(satok): Set correct data type if it's given.
    786   const int op_type_id = ops_definitions.GetOpIdFor("INPUT", {});
    787   CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount())
     788       << "Op " << node.name() << ", " << op_type << " is not supported, "
    789       << op_type_id;
    790   AppendNodeParamsWithIoParams(
    791       shape_refiner, node, node.name(), id, node.type_string(), op_type_id,
    792       PADDING_NA_ID, node.num_inputs(), {}, node.num_outputs(),
    793       true /* append_input */, true /* append_output */);
    794 }
    795 
    796 void GraphTransferer::RegisterFlattenNode(
    797     const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    798     const ShapeRefiner& shape_refiner, const Node& node) {
    799   VLOG(1) << "Register flatten node: " << node.name();
    800   CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
    801   const int id = node_name_to_id_cache_map_[node.name()];
    802   // TODO(satok): Remove dependency to specific type
    803   const string op_type = "FLATTEN";
    804   // TODO(satok): Set correct data type if it's given.
    805   const int op_type_id = ops_definitions.GetOpIdFor(op_type, {});
    806   CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount());
    807 
    808   AppendNodeParamsWithIoParams(
    809       shape_refiner, node, node.name(), id, node.type_string(), op_type_id,
    810       PADDING_NA_ID, node.num_inputs(), {}, node.num_outputs(),
    811       true /* append_input */, true /* append_output */);
    812 }
    813 
    814 void GraphTransferer::RegisterGenericNode(
    815     const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    816     const ShapeRefiner& shape_refiner, const Node& node) {
    817   VLOG(1) << "Register generic node: " << node.name();
    818   CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
    819   const int id = node_name_to_id_cache_map_[node.name()];
    820   // TODO(satok): Set correct data type if it's given.
    821   const int op_type_id = ops_definitions.GetOpIdFor(node.type_string(), {});
    822   CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount());
    823 
    824   AppendNodeParamsWithIoParams(
    825       shape_refiner, node, node.name(), id, node.type_string(), op_type_id,
    826       PADDING_NA_ID, node.num_inputs(), {}, node.num_outputs(),
    827       true /* append_input */, true /* append_output */);
    828 }
    829 
    830 // TODO(satok): Remove this function.
    831 // TODO(satok): Remove only_register_const_node.
    832 Status GraphTransferer::RegisterNodeIfAllInputsAreCached(
    833     const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    834     const ShapeRefiner& shape_refiner, const Node& node,
    835     const bool only_register_const_node,
    836     const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    837     const std::vector<string>& output_node_names) {
    838   if (only_register_const_node && !node.IsConstant()) {
    839     return Status();
    840   }
    841   CHECK(AreAllInputsCached(node));
    842   return RegisterNode(ops_definitions, shape_refiner, node,
    843                       input_node_info_list, output_node_names);
    844 }
    845 
     846 // CAVEAT: Input and output params must be appended separately to match these counts.
    847 void GraphTransferer::AppendNodeParams(const string& name, const int id,
    848                                        const string& type, const int type_id,
    849                                        const int padding, const int inputs_size,
    850                                        const std::vector<int>& extra_inputs,
    851                                        const int outputs_size) {
    852   GraphTransferInfo::NodeInfo& node_info =
    853       *graph_transfer_info_.add_node_info();
    854   node_info.set_name(name);
    855   node_info.set_node_id(id);
    856   node_info.set_type_name(type);
    857   node_info.set_soc_op_id(type_id);
    858   node_info.set_padding_id(padding);
    859   node_info.set_input_count(inputs_size +
    860                             static_cast<int>(extra_inputs.size()));
    861   node_info.set_output_count(static_cast<int>(outputs_size));
    862 }
    863 
    864 void GraphTransferer::AddNodeInputByInputIndex(
    865     const Node& node, const int idx,
    866     GraphTransferInfo::NodeInputInfo* node_input_info) {
    867   const Edge* edge = nullptr;
    868   TF_CHECK_OK(node.input_edge(idx, &edge));
    869   const Node* input_node = edge->src();
    870   CHECK_NOTNULL(input_node);
    871   const int port = edge->src_output();
    872 
    873   const std::string& op_name = input_node->name();
    874   CHECK_GT(node_name_to_id_cache_map_.count(op_name), 0) << op_name;
    875   const int src_id = node_name_to_id_cache_map_[op_name];
    876   GraphTransferInfo::NodeInput& node_input = *node_input_info->add_node_input();
    877   node_input.set_node_id(src_id);
    878   node_input.set_output_port(port);
    879 }
    880 
    881 void GraphTransferer::AppendNodeInputParams(
    882     const int id, const Node& node, const std::vector<int>& extra_inputs) {
    883   VLOG(1) << "Append input params: " << node.name() << ", " << node.num_inputs()
    884           << ", " << extra_inputs.size();
    885   GraphTransferInfo::NodeInputInfo& node_input_info =
    886       *graph_transfer_info_.add_node_input_info();
    887   node_input_info.set_node_id(id);
    888   for (int i = 0; i < node.num_inputs(); ++i) {
    889     AddNodeInputByInputIndex(node, i, &node_input_info);
    890   }
    891   for (const int extra_input : extra_inputs) {
    892     GraphTransferInfo::NodeInput& node_input =
    893         *node_input_info.add_node_input();
    894     node_input.set_node_id(extra_input);
    895     node_input.set_output_port(0);
    896   }
    897 }
    898 
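         // Records the maximum byte size of each output. The element count comes from
         // the ShapeRefiner when known, otherwise from the output shape attribute
         // attached to the node.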
    899 void GraphTransferer::AppendNodeOutputParams(const ShapeRefiner& shape_refiner,
    900                                              const int id, const Node& node) {
    901   VLOG(1) << "Append output params: " << node.name() << ", "
    902           << node.num_outputs();
    903   GraphTransferInfo::NodeOutputInfo& node_output_info =
    904       *graph_transfer_info_.add_node_output_info();
    905   node_output_info.set_node_id(id);
    906 
    907   std::vector<DataType> data_types;
    908   std::vector<TensorShape> shapes;
    909   Status status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
    910       node.attrs(), &data_types, &shapes);
    911 
    912   for (int i = 0; i < node.num_outputs(); ++i) {
    913     int data_size = -1;
    914     const int output_index = i;
    915     const DataType dt = node.output_type(output_index);
    916     const size_t max_bytes_per_data = DataTypeSize(dt);
    917 
    918     shape_inference::InferenceContext* context =
    919         shape_refiner.GetContext(&node);
    920 
    921     if (context != nullptr && context->ValueKnown(context->NumElements(
    922                                   context->output(output_index)))) {
    923       const shape_inference::DimensionHandle num_elements_dim =
    924           context->NumElements(context->output(output_index));
    925       const int64 num_output_elements = context->Value(num_elements_dim);
    926       data_size = max_bytes_per_data * num_output_elements;
    927       if (status.ok()) {
    928         TF_CHECK_OK(status);
    929         CHECK_EQ(shapes.at(i).num_elements(), num_output_elements);
    930       }
    931     } else {
    932       TF_CHECK_OK(status);
    933       // Use attribute attached to node
    934       data_size = max_bytes_per_data * shapes.at(i).num_elements();
    935     }
    936     CHECK_GE(data_size, 0);
    937     node_output_info.add_max_byte_size(data_size);
    938   }
    939 }
    940 
    941 void GraphTransferer::AppendNodeParamsWithIoParams(
    942     const ShapeRefiner& shape_refiner, const Node& node, const string& name,
    943     const int id, const string& type, const int type_id, const int padding,
    944     const int inputs_size, const std::vector<int>& extra_inputs,
    945     const int outputs_size, const bool append_input_params,
    946     const bool append_output_params) {
    947   VLOG(1) << "Append node with io params: " << node.name();
    948   if (append_input_params) {
    949     AppendNodeInputParams(id, node, extra_inputs);
    950   }
    951   if (append_output_params) {
    952     AppendNodeOutputParams(shape_refiner, id, node);
    953   }
    954   AppendNodeParams(name, id, type, type_id, padding, inputs_size, extra_inputs,
    955                    outputs_size);
    956 }
    957 
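         // Shapes are canonicalized into a fixed rank-4 array, left-padded with 1s.
         // For example, a rank-2 shape {3, 5} maps to {1, 1, 3, 5} and a scalar maps
         // to {1, 1, 1, 1}; ranks above 4 are not supported (LOG(FATAL)).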
    958 /* static */ std::array<int64, GraphTransferer::SHAPE_ARRAY_SIZE>
    959 GraphTransferer::BuildShapeArray(
    960     const shape_inference::ShapeHandle& shape_handle,
    961     shape_inference::InferenceContext* context) {
    962   switch (context->Rank(shape_handle)) {
    963     case 0:
    964       return std::array<int64, SHAPE_ARRAY_SIZE>{{1, 1, 1, 1}};
    965     case 1:
    966       return std::array<int64, SHAPE_ARRAY_SIZE>{
    967           {1, 1, 1, context->Value(context->Dim(shape_handle, 0))}};
    968     case 2:
    969       return std::array<int64, SHAPE_ARRAY_SIZE>{
    970           {1, 1, context->Value(context->Dim(shape_handle, 0)),
    971            context->Value(context->Dim(shape_handle, 1))}};
    972     case 3:
    973       return std::array<int64, SHAPE_ARRAY_SIZE>{
    974           {1, context->Value(context->Dim(shape_handle, 0)),
    975            context->Value(context->Dim(shape_handle, 1)),
    976            context->Value(context->Dim(shape_handle, 2))}};
    977     case 4:
    978       return std::array<int64, SHAPE_ARRAY_SIZE>{
    979           {context->Value(context->Dim(shape_handle, 0)),
    980            context->Value(context->Dim(shape_handle, 1)),
    981            context->Value(context->Dim(shape_handle, 2)),
    982            context->Value(context->Dim(shape_handle, 3))}};
    983     default:
    984       // TODO(satok): Support more ranks?
    985       LOG(FATAL);
    986       return std::array<int64, SHAPE_ARRAY_SIZE>();
    987   }
    988 }
    989 
    990 /* static */ std::array<int64, GraphTransferer::SHAPE_ARRAY_SIZE>
    991 GraphTransferer::ToTensorShapeArray(const TensorShape& shape) {
    992   switch (shape.dims()) {
    993     case 0:
    994       return std::array<int64, SHAPE_ARRAY_SIZE>{{1, 1, 1, 1}};
    995     case 1:
    996       return std::array<int64, SHAPE_ARRAY_SIZE>{{1, 1, 1, shape.dim_size(0)}};
    997     case 2:
    998       return std::array<int64, SHAPE_ARRAY_SIZE>{
    999           {1, 1, shape.dim_size(0), shape.dim_size(1)}};
   1000     case 3:
   1001       return std::array<int64, SHAPE_ARRAY_SIZE>{
   1002           {1, shape.dim_size(0), shape.dim_size(1), shape.dim_size(2)}};
   1003     case 4:
   1004       return std::array<int64, SHAPE_ARRAY_SIZE>{
   1005           {shape.dim_size(0), shape.dim_size(1), shape.dim_size(2),
   1006            shape.dim_size(3)}};
   1007     default:
   1008       // TODO(satok): Support more ranks?
   1009       LOG(FATAL);
   1010       return std::array<int64, SHAPE_ARRAY_SIZE>();
   1011   }
   1012 }
   1013 
   1014 /* static */ string GraphTransferer::ToPaddingDebugString(const int padding) {
   1015   switch (padding) {
   1016     case 0:
   1017       return "NN_PAD_NA";
   1018     case Padding::VALID:
   1019       return "NN_PAD_VALID";
   1020     case Padding::SAME:
   1021       return "NN_PAD_SAME";
   1022     default:
   1023       LOG(FATAL);
   1024       return "";
   1025   }
   1026 }
   1027 
   1028 GraphTransferer::TransferParamsComparator::TransferParamsComparator(
   1029     const std::unordered_map<int, std::unordered_set<int>>& dep_map)
   1030     : dependency_map_(dep_map) {}
   1031 
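         // Orders two nodes so that a node which (transitively) consumes the other is
         // placed after it; when neither depends on the other, ascending node id wins.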
   1032 bool GraphTransferer::TransferParamsComparator::operator()(
   1033     const GraphTransferInfo::NodeInfo& obj0,
   1034     const GraphTransferInfo::NodeInfo& obj1) {
   1035   const int node_id0 = obj0.node_id();
   1036   const int node_id1 = obj1.node_id();
   1037   bool obj0_uses_obj1 = false;
   1038   if (dependency_map_.count(node_id0) > 0) {
   1039     obj0_uses_obj1 = dependency_map_.at(node_id0).count(node_id1) > 0;
   1040   }
   1041   bool obj1_uses_obj0 = false;
   1042   if (dependency_map_.count(node_id1) > 0) {
   1043     obj1_uses_obj0 = dependency_map_.at(node_id1).count(node_id0) > 0;
   1044   }
   1045   CHECK(!obj0_uses_obj1 || !obj1_uses_obj0);
   1046   if (obj0_uses_obj1) {
   1047     return false;
   1048   } else if (obj1_uses_obj0) {
   1049     return true;
   1050   }
    1051   // If there is no dependency between the two nodes, the execution order is
    1052   // expected to follow node id order.
   1053   return node_id0 < node_id1;
   1054 }
   1055 
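         // Recursively folds each child's dependencies into dep_map[node_id] so the
         // entry ends up holding every direct and indirect dependency of node_id.
         // `completed` memoizes nodes whose dependency set is already fully expanded.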
   1056 /* static */ void GraphTransferer::FillDependencyRec(
   1057     const int node_id,
   1058     std::unordered_map<int, std::unordered_set<int>>& dep_map,
   1059     std::unordered_set<int>& completed) {
   1060   if (dep_map.count(node_id) == 0 || dep_map.at(node_id).empty() ||
   1061       completed.count(node_id) == 1) {
   1062     return;
   1063   }
   1064   CHECK_EQ(dep_map.count(node_id), 1);
   1065 
   1066   // Complete children's dependency map
   1067   for (int child_node_id : dep_map.at(node_id)) {
   1068     CHECK(child_node_id != node_id);
   1069     if (completed.count(child_node_id) != 0) {
   1070       continue;
   1071     }
   1072     FillDependencyRec(child_node_id, dep_map, completed);
   1073   }
   1074 
   1075   // Find additional depending ids
   1076   std::vector<int> depending_ids;
   1077   for (int child_node_id : dep_map.at(node_id)) {
   1078     if (dep_map.count(child_node_id) == 0) {
   1079       continue;
   1080     }
   1081     for (int depending_id : dep_map.at(child_node_id)) {
   1082       depending_ids.emplace_back(depending_id);
   1083     }
   1084   }
   1085 
   1086   // Insert additional depending ids
   1087   for (int depending_id : depending_ids) {
   1088     if (dep_map.at(node_id).count(depending_id) == 0) {
   1089       dep_map.at(node_id).emplace(depending_id);
   1090     }
   1091   }
   1092 
   1093   // DP: Record completed node id
   1094   completed.emplace(node_id);
   1095 }
   1096 
   1097 /* static */ Status GraphTransferer::MakeTensorFromProto(
   1098     const TensorProto& tensor_proto, Tensor* tensor) {
   1099   if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) {
   1100     Tensor parsed(tensor_proto.dtype());
   1101     if (parsed.FromProto(cpu_allocator(), tensor_proto)) {
   1102       *tensor = parsed;
   1103       return Status::OK();
   1104     }
   1105   }
   1106   return errors::InvalidArgument("Cannot parse tensor from proto: ",
   1107                                  tensor_proto.DebugString());
   1108 }
   1109 
   1110 void GraphTransferer::ClearCache() {
   1111   node_name_cache_list_.clear();
   1112   node_name_to_id_cache_map_.clear();
   1113 }
   1114 
   1115 void GraphTransferer::DumpNodeTransferParams() const {
   1116   LOG(INFO) << "*** Const Nodes ***";
   1117   for (const GraphTransferInfo::ConstNodeInfo& params :
   1118        graph_transfer_info_.const_node_info()) {
   1119     // TODO(satok): Stop assuming shape size is 4.
   1120     CHECK_EQ(params.shape_size(), 4);
   1121     LOG(INFO) << "[ " << params.node_id() << " \"" << params.name()
   1122               << "\" (Const)";
    1123     LOG(INFO) << "  shape: " << params.shape(0) << ", " << params.shape(1)
    1124               << ", " << params.shape(2) << ", " << params.shape(3);
   1125     LOG(INFO) << "  data_name: "
   1126               << (params.data().length() <= 0
   1127                       ? ""
   1128                       : DATA_NODE_PREFIX + ToString(params.node_id()));
   1129     LOG(INFO) << "  data_size: " << params.data().length() << " bytes"
   1130               << " ]";
   1131   }
   1132   LOG(INFO) << "******\n";
   1133   LOG(INFO) << "*** Op Nodes ***";
   1134   for (const GraphTransferInfo::NodeInfo& params :
   1135        graph_transfer_info_.node_info()) {
   1136     LOG(INFO) << "[ " << params.node_id() << " \"" << params.name();
   1137     LOG(INFO) << "  type: " << params.type_name();
   1138     LOG(INFO) << "  padding: " << ToPaddingDebugString(params.padding_id());
   1139     LOG(INFO) << "  inputs: " << INPUTS_NODE_PREFIX + ToString(params.node_id())
   1140               << ", size = " << params.input_count();
   1141     LOG(INFO) << "  outputs: "
   1142               << (params.output_count() <= 0
   1143                       ? NULL_OUTPUT_NAME
   1144                       : (OUTPUTS_NODE_PREFIX + ToString(params.node_id())))
   1145               << ", size = " << params.output_count() << " ]";
   1146   }
   1147   LOG(INFO) << "******\n";
   1148   LOG(INFO) << "*** Node input params ***";
   1149   for (const GraphTransferInfo::NodeInputInfo& params :
   1150        graph_transfer_info_.node_input_info()) {
   1151     LOG(INFO) << "[ " << params.node_id() << " ]";
   1152     for (const GraphTransferInfo::NodeInput& node_input : params.node_input()) {
   1153       LOG(INFO) << "    src node id = " << node_input.node_id()
   1154                 << ", output port = " << node_input.output_port();
   1155     }
   1156   }
   1157   LOG(INFO) << "******\n";
   1158   LOG(INFO) << "*** Node output params ***";
   1159   for (const GraphTransferInfo::NodeOutputInfo& params :
   1160        graph_transfer_info_.node_output_info()) {
   1161     LOG(INFO) << "[ " << params.node_id() << " ]";
   1162     for (const int max_size : params.max_byte_size()) {
   1163       LOG(INFO) << "    max_size = " << max_size;
   1164     }
   1165   }
   1166   LOG(INFO) << "******\n";
   1167 }
   1168 
   1169 void GraphTransferer::DumpVerificationStringOfNodeTransferParams() const {
   1170   for (const GraphTransferInfo::ConstNodeInfo& params :
   1171        graph_transfer_info_.const_node_info()) {
   1172     std::stringstream sstream;
   1173     // TODO(satok): Stop assuming shape size is 4.
   1174     CHECK_EQ(params.shape_size(), 4);
   1175     sstream << "---(CONST) [" << std::hex << params.node_id() << std::dec << ","
   1176             << params.shape(0) << "," << params.shape(1) << ","
   1177             << params.shape(2) << "," << params.shape(3) << ","
   1178             << (params.data().length() <= 0
   1179                     ? ""
   1180                     : DATA_NODE_PREFIX + ToString(params.node_id()))
   1181             << "," << params.data().length() << "," << params.name() << "]";
   1182     LOG(INFO) << sstream.str();
   1183   }
   1184   LOG(INFO) << "Const node count = "
   1185             << graph_transfer_info_.const_node_info_size();
   1186   for (const GraphTransferInfo::NodeInfo& params :
   1187        graph_transfer_info_.node_info()) {
   1188     std::stringstream sstream;
   1189     sstream << "---(OP) [" << params.name().c_str() << "," << std::hex
   1190             << params.node_id() << std::dec << "," << params.soc_op_id() << ","
   1191             << ToPaddingDebugString(params.padding_id()) << ","
   1192             << INPUTS_NODE_PREFIX + ToString(params.node_id()) << ","
   1193             << params.input_count() << ","
   1194             << (params.output_count() <= 0
   1195                     ? NULL_OUTPUT_NAME
   1196                     : (OUTPUTS_NODE_PREFIX + ToString(params.node_id())))
   1197             << "," << params.output_count() << "," << params.type_name() << "]";
   1198     LOG(INFO) << sstream.str();
   1199   }
   1200   LOG(INFO) << "Op node count = " << graph_transfer_info_.node_info_size();
   1201   for (const GraphTransferInfo::NodeInputInfo& params :
   1202        graph_transfer_info_.node_input_info()) {
   1203     std::stringstream sstream;
   1204     sstream << "---(INPUT) [" << std::hex << params.node_id() << std::dec;
   1205     for (const GraphTransferInfo::NodeInput& node_input : params.node_input()) {
   1206       sstream << "," << std::hex << node_input.node_id() << std::dec << ","
   1207               << node_input.output_port();
   1208     }
   1209     sstream << "]";
   1210     LOG(INFO) << sstream.str();
   1211   }
   1212   LOG(INFO) << "Input params count = "
   1213             << graph_transfer_info_.node_input_info_size();
   1214   for (const GraphTransferInfo::NodeOutputInfo& params :
   1215        graph_transfer_info_.node_output_info()) {
   1216     std::stringstream sstream;
   1217     sstream << "---(OUTPUT) [" << std::hex << params.node_id() << std::dec;
   1218     for (const int max_size : params.max_byte_size()) {
   1219       sstream << "," << max_size;
   1220     }
   1221     sstream << "]";
   1222     LOG(INFO) << sstream.str();
   1223   }
   1224   LOG(INFO) << "Output params count = "
   1225             << graph_transfer_info_.node_output_info_size();
   1226 }
   1227 
   1228 }  // namespace tensorflow
   1229