/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"

#include <algorithm>
#include <queue>
#include <utility>

#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace {
const Node* FindNodeByName(const string& name, const Graph& graph) {
  for (const Node* node : graph.nodes()) {
    CHECK_NOTNULL(node);
    if (node->name() == name) {
      return node;
    }
  }
  return nullptr;
}

std::unordered_set<string> BuildNodeSetFromNodeNamesAndPorts(
    const std::vector<string>& node_names_and_ports) {
  std::unordered_set<string> retval;
  for (const string& node_name_and_port : node_names_and_ports) {
    const TensorId tid = ParseTensorName(node_name_and_port);
    retval.emplace(tid.first.ToString());
  }
  return retval;
}

Node* FindMutableNodeByName(const string& name, Graph* graph) {
  for (Node* node : graph->nodes()) {
    if (node != nullptr && node->name() == name) {
      return node;
    }
  }
  return nullptr;
}

const NodeDef* FindNodeDefByName(const string& input,
                                 const GraphDef& graph_def) {
  const TensorId tid = ParseTensorName(input);
  const string name = tid.first.ToString();
  for (const NodeDef& node_def : graph_def.node()) {
    if (node_def.name() == name) {
      return &node_def;
    }
  }
  return nullptr;
}

bool IsSameNodeName(const NodeDef& node_def, const string& node_name_and_port,
                    TensorId* tid) {
  CHECK_NOTNULL(tid);
  *tid = ParseTensorName(node_name_and_port);
  if (node_def.name() == tid->first.ToString()) {
    return true;
  }
  return false;
}

bool ContainsSameTensorId(const string& tensor_name,
                          const std::vector<string>& tensor_names) {
  const TensorId tid0 = ParseTensorName(tensor_name);
  for (const string& name : tensor_names) {
    const TensorId tid1 = ParseTensorName(name);
    if (tid0.first == tid1.first && tid0.second == tid1.second) {
      return true;
    }
  }
  return false;
}

void AppendDeliminator(string* str) {
  CHECK_NOTNULL(str);
  if (!str->empty()) {
    *str += ":";
  }
}

void ConvertMapToVector(const std::unordered_map<int, string>& in,
                        std::vector<string>* out) {
  CHECK_NOTNULL(out);
  out->resize(in.size());
  for (size_t i = 0; i < in.size(); ++i) {
    CHECK(in.count(i) > 0);
    out->at(i) = in.at(i);
  }
}

string DumpGraphDef(const GraphDef& graph_def) {
  string out;
  for (const NodeDef& node : graph_def.node()) {
    out += strings::StrCat("node: ", node.name(), "\n    input: ");
    for (const string& input : node.input()) {
      out += strings::StrCat(input, ", ");
    }
    out += "\n";
  }
  return out;
}

string DumpCluster(const RemoteFusedGraphExecuteUtils::ClusterInfo& cluster) {
  string out;
  out += "Nodes:\n";
  for (const string& str : std::get<0>(cluster)) {
    out += str + ", ";
  }
  out += "\nInput border:\n";
  for (const string& str : std::get<1>(cluster)) {
    out += str + ", ";
  }
  out += "\nOutput border:\n";
  for (const string& str : std::get<2>(cluster)) {
    out += str + ", ";
  }
  return out;
}

}  // namespace

/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES;
/* static */ constexpr const char* const RemoteFusedGraphExecuteUtils::
    ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::ATTR_NODE_TYPE;
/* static */ constexpr const char* const RemoteFusedGraphExecuteUtils::
    TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_NODES;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_INPUTS;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_SHAPES;

RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar::ExecutorBuildRegistrar(
    const string& name, ExecutorBuildFunc executor_build_func) {
  ExecutorBuildRegistry& executor_build_registry = *GetExecutorBuildRegistry();
  executor_build_registry[name] = std::move(executor_build_func);
}

/* static */ const RemoteFusedGraphExecuteUtils::ExecutorBuildFunc*
RemoteFusedGraphExecuteUtils::GetExecutorBuildFunc(const string& name) {
  ExecutorBuildRegistry& executor_build_registry = *GetExecutorBuildRegistry();
  if (executor_build_registry.count(name) <= 0) {
    return nullptr;
  }
  return &executor_build_registry.at(name);
}

/* static */ RemoteFusedGraphExecuteUtils::ExecutorBuildRegistry*
RemoteFusedGraphExecuteUtils::GetExecutorBuildRegistry() {
  static ExecutorBuildRegistry executor_builder_registry;
  return &executor_builder_registry;
}
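
// A minimal registration sketch (the executor name "my_executor" and the
// builder lambda are hypothetical, not part of this file): executors
// register themselves at static-initialization time and are looked up by
// name later via GetExecutorBuildFunc().
//
//   static RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
//       my_executor_registrar(
//           "my_executor",
//           [](std::unique_ptr<IRemoteFusedGraphExecutor>* executor)
//               -> Status {
//             executor->reset(new MyRemoteFusedGraphExecutor());
//             return Status::OK();
//           });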

/**
 * - DryRunInference
 * Dry-runs the graph to determine the shapes of the output tensors of all
 * nodes. The dry run supplies memory allocation information when loading
 * the graph, and is used to verify that inferred shapes match the actual
 * output shapes.
 */
/* static */ Status RemoteFusedGraphExecuteUtils::DryRunInference(
    const GraphDef& graph_def,
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    const std::vector<string>& output_node_names, const bool initialize_by_zero,
    std::vector<tensorflow::Tensor>* output_tensors) {
  // Create the input tensor vector.  If "initialize_by_zero" is true,
  // input tensor fields are initialized to 0.
  std::vector<std::pair<string, tensorflow::Tensor>> input_tensors;
  for (const std::pair<string, Tensor>& input : input_node_info_list) {
    CHECK(input.second.IsInitialized());
    if (!initialize_by_zero) {
      input_tensors.push_back({input.first, input.second});
      continue;
    }
    // The given tensor supplies only the shape and type; build a
    // zero-filled tensor of that shape and type.
    const DataType data_type = input.second.dtype();
    const TensorShape& shape = input.second.shape();
    Tensor input_tensor(data_type, shape);
    switch (data_type) {
      case DT_INT32: {
        auto int_tensor = input_tensor.flat<int32>();
        int_tensor = int_tensor.constant(0);
        break;
      }
      case DT_FLOAT: {
        auto float_tensor = input_tensor.flat<float>();
        float_tensor = float_tensor.constant(0.0f);
        break;
      }
      case DT_QUINT8: {
        auto int_tensor = input_tensor.flat<quint8>();
        int_tensor = int_tensor.constant(0);
        break;
      }
      default:
        LOG(FATAL) << "Unsupported input type: " << data_type;
    }
    input_tensors.push_back({input.first, input_tensor});
  }

  // Set up the session.
  CHECK(output_tensors != nullptr);
  SessionOptions session_options;
  session_options.env = Env::Default();
  std::unique_ptr<Session> session =
      std::unique_ptr<Session>(NewSession(session_options));
  Status status = session->Create(graph_def);
  if (!status.ok()) {
    return status;
  }

  // Set up the session arguments.
  RunOptions run_options;
  run_options.set_trace_level(RunOptions::FULL_TRACE);
  RunMetadata run_metadata;

  // Run inference with all requested nodes as outputs.
  status = session->Run(run_options, input_tensors, output_node_names, {},
                        output_tensors, &run_metadata);
  if (!status.ok()) {
    LOG(ERROR) << "Error during inference: " << status;
    return status;
  }
  return Status();
}
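
// Usage sketch (node names are hypothetical, not part of this file): dry-run
// a graph with one zero-filled float input and fetch a single output tensor.
//
//   std::vector<std::pair<string, Tensor>> inputs{
//       {"input_node", Tensor(DT_FLOAT, TensorShape({1, 224, 224, 3}))}};
//   std::vector<Tensor> outputs;
//   TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::DryRunInference(
//       graph_def, inputs, {"output_node:0"}, /*initialize_by_zero=*/true,
//       &outputs));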

/* static */ Status RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode(
    const GraphDef& graph_def,
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    const bool initialize_by_zero,
    RemoteFusedGraphExecuteUtils::TensorShapeMap* tensor_shape_map) {
  CHECK(tensor_shape_map != nullptr);
  std::vector<Tensor> output_tensors;
  output_tensors.reserve(graph_def.node_size());
  std::vector<string> output_node_names;

  Graph graph(OpRegistry::Global());
  Status status = ImportGraphDef({}, graph_def, &graph, nullptr);
  if (!status.ok()) {
    return status;
  }

  for (const Node* node : graph.nodes()) {
    if (IsInputNode(input_node_info_list, node->name())) {
      continue;
    }
    for (int i = 0; i < node->num_outputs(); ++i) {
      output_node_names.emplace_back(strings::StrCat(node->name(), ":", i));
    }
  }

  status = DryRunInference(graph_def, input_node_info_list, output_node_names,
                           initialize_by_zero, &output_tensors);
  if (!status.ok()) {
    VLOG(1) << "Failed to dry-run: " << status;
    return status;
  }

  CHECK_EQ(output_node_names.size(), output_tensors.size())
      << output_node_names.size() << ", " << output_tensors.size();

  // Append the output tensors of the input nodes up front so that the map
  // can be built without memory reallocation inside the vector.
  for (const std::pair<string, Tensor>& input_node_info :
       input_node_info_list) {
    output_tensors.push_back(input_node_info.second);
  }

  for (int i = 0; static_cast<size_t>(i) < output_node_names.size(); ++i) {
    const string& name = output_node_names.at(i);
    const Tensor& tensor = output_tensors.at(i);
    EmplaceTensorShapeType(name, tensor, tensor_shape_map);
  }
  for (int i = 0; static_cast<size_t>(i) < input_node_info_list.size(); ++i) {
    const string& name = input_node_info_list.at(i).first;
    const Tensor& tensor = output_tensors.at(output_node_names.size() + i);
    EmplaceTensorShapeType(name, tensor, tensor_shape_map);
  }
  CHECK_EQ(output_node_names.size() + input_node_info_list.size(),
           output_tensors.size());
  return status;
}
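
// Note: on success, tensor_shape_map holds one (port, (dtype, shape)) entry
// keyed by node name for every output port of every non-input node, plus one
// entry per input node taken from the tensors in input_node_info_list.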

/* static */ bool RemoteFusedGraphExecuteUtils::IsInputNode(
    const std::vector<std::pair<string, Tensor>>& input_tensor_vector,
    const string& node_name) {
  for (const std::pair<string, Tensor>& pair : input_tensor_vector) {
    const TensorId tid = ParseTensorName(pair.first);
    if (node_name == tid.first.ToString()) {
      return true;
    }
  }
  return false;
}

/* static */ void RemoteFusedGraphExecuteUtils::ConvertToTensorShapeMap(
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    const std::vector<string>& output_node_names,
    const std::vector<tensorflow::Tensor>& output_tensors,
    TensorShapeMap* tensor_shape_map) {
  CHECK_NE(tensor_shape_map, nullptr);
  tensor_shape_map->clear();
  tensor_shape_map->reserve(input_node_info_list.size() +
                            output_node_names.size());
  const int output_node_count = output_node_names.size();
  CHECK_EQ(output_node_count, output_tensors.size());
  for (int i = 0; i < output_node_count; ++i) {
    const string& node_name = output_node_names.at(i);
    const Tensor& tensor = output_tensors.at(i);
    EmplaceTensorShapeType(node_name, tensor, tensor_shape_map);
  }
}

/* static */ Status RemoteFusedGraphExecuteUtils::MakeTensorFromProto(
    const TensorProto& tensor_proto, Tensor* tensor) {
  if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) {
    Tensor parsed(tensor_proto.dtype());
    if (parsed.FromProto(cpu_allocator(), tensor_proto)) {
      *tensor = parsed;
      return Status::OK();
    }
  }
  return errors::InvalidArgument("Cannot parse tensor from proto");
}

/* static */ bool RemoteFusedGraphExecuteUtils::AddOutputTensorShapeType(
    const std::vector<DataType>& data_types,
    const std::vector<TensorShape>& shapes, NodeDef* node_def) {
  AddNodeAttr(ATTR_OUTPUT_DATA_TYPES, data_types, node_def);
  AddNodeAttr(ATTR_OUTPUT_SHAPES, shapes, node_def);
  return true;
}

/* static */ Status
RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap(
    const TensorShapeMap& tensor_shape_map, NodeDef* node_def) {
  CHECK_NE(node_def, nullptr);
  std::priority_queue<std::tuple<int, const TensorShapeType*>> queue;
  auto its = tensor_shape_map.equal_range(node_def->name());
  for (auto it = its.first; it != its.second; ++it) {
    queue.emplace(std::make_tuple(it->second.first, &it->second.second));
  }
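  // The max-heap pops ports in descending order.  Emplacing each entry at
  // the front of the vectors restores ascending port order, and the CHECK
  // below verifies that the ports form the contiguous range
  // [0, queue.size()).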
    382   std::vector<DataType> data_types;
    383   std::vector<TensorShape> shapes;
    384   while (!queue.empty()) {
    385     const int port = std::get<0>(queue.top());
    386     const TensorShapeType* tst = std::get<1>(queue.top());
    387     CHECK_NE(tst, nullptr);
    388     data_types.emplace(data_types.begin(), tst->first);
    389     shapes.emplace(shapes.begin(), tst->second);
    390     CHECK_EQ(last_port - 1, port);
    391     last_port = port;
    392     queue.pop();
    393   }
    394   AddOutputTensorShapeType(data_types, shapes, node_def);
    395   return Status::OK();
    396 }
    397 
    398 /* static */ Status RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
    399     AttrSlice attrs, std::vector<DataType>* data_types,
    400     std::vector<TensorShape>* shapes) {
    401   Status status;
    402   if (data_types != nullptr) {
    403     status = GetNodeAttr(attrs, ATTR_OUTPUT_DATA_TYPES, data_types);
    404   }
    405   if (!status.ok()) {
    406     return status;
    407   }
    408   if (shapes != nullptr) {
    409     status = GetNodeAttr(attrs, ATTR_OUTPUT_SHAPES, shapes);
    410     if (status.ok() && data_types != nullptr) {
    411       CHECK_EQ(data_types->size(), shapes->size());
    412     }
    413   }
    414 
    415   return status;
    416 }
    417 
    418 /* static */ bool RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
    419     const GraphDef& graph_def, const string& name_and_port, DataType* data_type,
    420     TensorShape* shape) {
    421   std::vector<DataType> data_types;
    422   std::vector<TensorShape> shapes;
    423   const TensorId tid = ParseTensorName(name_and_port);
    424   const string node_name = tid.first.ToString();
    425   const int port = tid.second;
    426   const NodeDef* node_def = FindNodeDefByName(node_name, graph_def);
    427   CHECK_NOTNULL(node_def);
    428   GetOutputTensorShapeType(*node_def, &data_types, &shapes).IgnoreError();
    429   if (data_types.empty()) {
    430     return false;
    431   }
    432   CHECK(data_types.size() > port);
    433   *data_type = data_types.at(port);
    434   *shape = shapes.at(port);
    435   return true;
    436 }
    437 
    438 /* static */ Status RemoteFusedGraphExecuteUtils::PropagateShapeInference(
    439     const GraphDef& graph_def,
    440     const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    441     Graph* graph, ShapeRefiner* shape_refiner) {
    442   Status status;
    443   auto visit = [&shape_refiner, &input_node_info_list, &status](Node* node) {
    444     if (!status.ok()) {
    445       return;
    446     }
    447     CHECK_NE(node, nullptr);
    448     // If we visit an input node, we use the shape provided and set the
    449     // shape accordingly.
    450     bool is_input_node = false;
    451     for (const std::pair<string, Tensor>& input_node_info :
    452          input_node_info_list) {
    453       if (node->name() == input_node_info.first) {
    454         shape_inference::InferenceContext* context =
    455             shape_refiner->GetContext(node);
    456         shape_inference::ShapeHandle handle;
    457         status = context->MakeShapeFromTensorShape(
    458             input_node_info.second.shape(), &handle);
    459         if (!status.ok()) {
    460           break;
    461         }
    462         status = shape_refiner->SetShape(node, 0, handle);
    463         if (!status.ok()) {
    464           break;
    465         }
    466         is_input_node = true;
    467       }
    468       if (!status.ok()) {
    469         break;
    470       }
    471     }
    472     // If not an input node call AddNode() that recomputes the shape.
    473     if (!is_input_node && status.ok()) {
    474       status = shape_refiner->AddNode(node);
    475     }
    476     if (!status.ok()) {
    477       VLOG(1) << "Shape inference failed for node: " << node->name();
    478     }
    479   };
    480 
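  // Visit in post-order of a reverse DFS so that a node is processed only
  // after all of its inputs; ShapeRefiner::AddNode() requires the shapes of
  // a node's inputs to be known already.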
  ReverseDFS(*graph, {}, visit);

  return status;
}

/* static */ Status RemoteFusedGraphExecuteUtils::BuildTensorShapeMapFromGraph(
    const Graph& graph, const ShapeRefiner& shape_refiner,
    TensorShapeMap* tensor_shape_map) {
  for (int i = 0; i < graph.num_node_ids(); ++i) {
    const Node* node = graph.FindNodeId(i);
    CHECK_NE(node, nullptr);
    for (int j = 0; j < node->num_outputs(); ++j) {
      const int output_index = j;
      const DataType dt = node->output_type(output_index);
      shape_inference::InferenceContext* context =
          shape_refiner.GetContext(node);
      CHECK_NE(context, nullptr);
      shape_inference::ShapeHandle shape_handle = context->output(output_index);
      if (context->RankKnown(shape_handle)) {
        TensorShape ts;
        for (int k = 0; k < context->Rank(shape_handle); ++k) {
          shape_inference::DimensionHandle dh = context->Dim(shape_handle, k);
          CHECK(context->ValueKnown(dh));
          ts.AddDim(context->Value(dh));
        }
        const string& node_name = node->name();
        CHECK(tensor_shape_map->count(node_name) == 0);
        tensor_shape_map->emplace(node_name,
                                  std::make_pair(j, std::make_pair(dt, ts)));
      } else {
        return errors::InvalidArgument("Graph contains unknown shapes");
      }
    }
  }
  return Status::OK();
}

/* static */ const RemoteFusedGraphExecuteUtils::TensorShapeType*
RemoteFusedGraphExecuteUtils::GetTensorShapeType(
    const TensorShapeMap& tensor_shape_map, const string& node_name) {
  if (node_name.find(':') != string::npos) {
    const TensorId tid = ParseTensorName(node_name);
    return GetTensorShapeType(tensor_shape_map, tid.first.ToString(),
                              tid.second);
  } else {
    return GetTensorShapeType(tensor_shape_map, node_name, 0);
  }
}

/* static */ const RemoteFusedGraphExecuteUtils::TensorShapeType*
RemoteFusedGraphExecuteUtils::GetTensorShapeType(
    const TensorShapeMap& tensor_shape_map, const string& node_name,
    const int port) {
  CHECK_EQ(node_name.find(':'), string::npos);
  if (tensor_shape_map.count(node_name) <= 0) {
    return nullptr;
  }
  auto its = tensor_shape_map.equal_range(node_name);
  for (auto it = its.first; it != its.second; ++it) {
    if (it->second.first == port) {
      return &it->second.second;
    }
  }
  return nullptr;
}

/* static */ void
RemoteFusedGraphExecuteUtils::BuildRemoteGraphInputsAndOutputsFromProto(
    const RemoteFusedGraphExecuteInfo& proto,
    std::vector<std::pair<string, Tensor>>* inputs,
    std::vector<string>* outputs) {
  CHECK_EQ(proto.graph_input_node_name_size(),
           proto.default_graph_input_tensor_shape_size());
  for (int i = 0; i < proto.graph_input_node_name_size(); ++i) {
    inputs->emplace_back(
        proto.graph_input_node_name(i),
        Tensor(proto.default_graph_input_tensor_shape(i).dtype(),
               TensorShape(proto.default_graph_input_tensor_shape(i).shape())));
  }
  for (const string& output_node_name : proto.graph_output_node_name()) {
    outputs->emplace_back(output_node_name);
  }
}

/* static */ void RemoteFusedGraphExecuteUtils::EmplaceTensorShapeType(
    const string& name, const Tensor& tensor,
    TensorShapeMap* tensor_shape_map) {
  const TensorId tid = ParseTensorName(name);
  CHECK_EQ(tensor_shape_map->count(name), 0);
  tensor_shape_map->emplace(
      tid.first.ToString(),
      std::make_pair(tid.second,
                     std::make_pair(tensor.dtype(), tensor.shape())));
}
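
// Note: a TensorShapeMap can hold several entries under the same node-name
// key, one per output port; the port is stored in the value, which is why
// lookups above use equal_range() and match on the stored port.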

/* static */ Status RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes(
    const std::vector<std::pair<string, Tensor>>& input_tensors,
    const bool dry_run_inference, GraphDef* graph_def) {
  TensorShapeMap tensor_shape_map;
  if (dry_run_inference) {
    TF_RETURN_IF_ERROR(DryRunInferenceForAllNode(*graph_def, input_tensors,
                                                 /*initialize_by_zero=*/true,
                                                 &tensor_shape_map));
  } else {
    ImportGraphDefOptions opts;
    Graph graph(OpRegistry::Global());
    ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
    TF_RETURN_IF_ERROR(
        ImportGraphDef(opts, *graph_def, &graph, &shape_refiner));
    TF_RETURN_IF_ERROR(PropagateShapeInference(*graph_def, input_tensors,
                                               &graph, &shape_refiner));
    TF_RETURN_IF_ERROR(
        BuildTensorShapeMapFromGraph(graph, shape_refiner, &tensor_shape_map));
  }

  for (NodeDef& node_def : *graph_def->mutable_node()) {
    TF_RETURN_IF_ERROR(
        AddOutputTensorShapeTypeByTensorShapeMap(tensor_shape_map, &node_def));
  }

  return Status::OK();
}
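
// Usage sketch (names are hypothetical, not part of this file): annotate
// every node of a GraphDef with output shape/type attributes via static
// shape inference rather than an actual dry run.
//
//   std::vector<std::pair<string, Tensor>> inputs{
//       {"input_node", Tensor(DT_FLOAT, TensorShape({1, 8}))}};
//   TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes(
//       inputs, /*dry_run_inference=*/false, &graph_def));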

/* static */ Status
RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteInfo(
    const string& executor_name, const GraphDef& subgraph_def,
    const std::vector<string>& inputs, const std::vector<string>& outputs,
    const bool require_shape_type, RemoteFusedGraphExecuteInfo* execute_info,
    DataTypeVector* input_types, DataTypeVector* output_types) {
  CHECK_NOTNULL(execute_info);
  CHECK_NOTNULL(input_types);
  CHECK_NOTNULL(output_types);

  execute_info->Clear();
  execute_info->set_executor_name(executor_name);

  // Copy the subgraph.
  *execute_info->mutable_remote_graph() = subgraph_def;

  for (const string& input : inputs) {
    DataType dt;
    TensorShape shape;
    const bool has_shapetype =
        GetOutputTensorShapeType(subgraph_def, input, &dt, &shape);

    execute_info->add_graph_input_node_name(input);
    if (has_shapetype) {
      RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type =
          *execute_info->add_default_graph_input_tensor_shape();
      tensor_shape_type.set_dtype(dt);
      TensorShapeProto& tensor_shape_proto = *tensor_shape_type.mutable_shape();
      for (const int64 dim : shape.dim_sizes()) {
        tensor_shape_proto.add_dim()->set_size(dim);
      }
      input_types->push_back(dt);
    } else {
      CHECK(!require_shape_type)
          << "No shape type found for " << input << DumpGraphDef(subgraph_def);
      // Assume the input type is float if no shape/type data is provided.
      input_types->push_back(DT_FLOAT);
    }
  }

  for (const string& output : outputs) {
    DataType dt;
    TensorShape shape;
    const bool has_shapetype =
        GetOutputTensorShapeType(subgraph_def, output, &dt, &shape);

    execute_info->add_graph_output_node_name(output);
    if (has_shapetype) {
      RemoteFusedGraphExecuteInfo::TensorShapeTypeProto&
          tensor_shape_type_proto =
              *execute_info->add_default_graph_output_tensor_shape();
      tensor_shape_type_proto.set_dtype(dt);
      TensorShapeProto& tensor_shape_proto =
          *tensor_shape_type_proto.mutable_shape();
      for (const int64 dim : shape.dim_sizes()) {
        tensor_shape_proto.add_dim()->set_size(dim);
      }
      output_types->push_back(dt);
    } else {
      CHECK(!require_shape_type)
          << "No shape type found for " << output << DumpGraphDef(subgraph_def);
      // Assume the output type is float if no shape/type data is provided.
      output_types->push_back(DT_FLOAT);
    }
  }

  return Status::OK();
}

/* static */ Status
RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
    const string& node_name, const string& executor_name,
    const GraphDef& subgraph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs, const bool require_shape_type,
    Graph* graph, Node** created_node) {
  CHECK_NOTNULL(graph);
  CHECK_NOTNULL(created_node);

  RemoteFusedGraphExecuteInfo execute_info;
  DataTypeVector input_types;
  DataTypeVector output_types;

  TF_CHECK_OK(RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteInfo(
      executor_name, subgraph_def, inputs, outputs, require_shape_type,
      &execute_info, &input_types, &output_types));

  std::vector<NodeBuilder::NodeOut> node_out_list;
  for (const string& input : inputs) {
    const TensorId tid = ParseTensorName(input);
    Node* node = FindMutableNodeByName(tid.first.ToString(), graph);
    CHECK_NOTNULL(node);
    node_out_list.emplace_back(node, tid.second);
  }

  const string execute_info_str = execute_info.SerializeAsString();

  auto builder =
      NodeBuilder(node_name, "RemoteFusedGraphExecute")
          .Input(node_out_list)
          .Attr("Tinputs", input_types)
          .Attr("Toutputs", output_types)
          .Attr("serialized_remote_fused_graph_execute_info", execute_info_str);

  TF_RETURN_IF_ERROR(builder.Finalize(graph, created_node));
  return Status::OK();
}

/* static */ Status RemoteFusedGraphExecuteUtils::BuildIdentityOpNode(
    const string& node_name, const string& input_node_name,
    const int input_node_port, const DataType dt, Graph* graph,
    Node** created_node) {
  Node* node = FindMutableNodeByName(input_node_name, graph);
  CHECK_NOTNULL(node);
  NodeBuilder::NodeOut node_out(node, input_node_port);

  auto builder =
      NodeBuilder(node_name, "Identity").Input(node_out).Attr("T", dt);

  TF_RETURN_IF_ERROR(builder.Finalize(graph, created_node));
  return Status::OK();
}

/* static */ Status RemoteFusedGraphExecuteUtils::ClusterizeNodes(
    const std::unordered_set<string>& node_names, const GraphDef& graph_def,
    std::vector<ClusterInfo>* cluster_infos) {
  Graph graph(OpRegistry::Global());
  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner));
  std::unordered_set<string> remaining_nodes = node_names;

  while (!remaining_nodes.empty()) {
    ClusterInfo ci;

    // Determine the nodes of one cluster.
    std::unordered_set<const Node*> visited;
    std::deque<const Node*> queue;
    queue.emplace_back(FindNodeByName(*remaining_nodes.begin(), graph));
    while (!queue.empty()) {
      const Node* node = queue.front();
      CHECK_NOTNULL(node);
      queue.pop_front();
      const string& node_name = node->name();
      if (node_names.count(node_name) > 0) {
        std::get<0>(ci).emplace(node_name);
        remaining_nodes.erase(node_name);
      } else {
        // Edge of the subgraph.  Do nothing.
        continue;
      }
      for (const Node* in : node->in_nodes()) {
        if (visited.insert(in).second) {
          queue.push_back(in);
        }
      }
      for (const Node* out : node->out_nodes()) {
        if (visited.insert(out).second) {
          queue.push_back(out);
        }
      }
    }

    // Determine the border of the cluster.
    std::vector<string>& border_inputs = std::get<1>(ci);
    std::vector<string>& border_outputs = std::get<2>(ci);
    for (const string& node_name : node_names) {
      Node* node = FindMutableNodeByName(node_name, &graph);
      CHECK_NOTNULL(node);
      int input_count = 0;
      for (const Edge* in_edge : node->in_edges()) {
        const Node* src_node = in_edge->src();
        const bool src_is_outside =
            node_names.count(src_node->name()) <= 0 && !src_node->IsSource();
        if (src_is_outside) {
          const string src_name =
              strings::StrCat(src_node->name(), ":", in_edge->src_output());
          CHECK_EQ(1, src_node->num_outputs())
              << "The output count of an input border node must be one: "
              << src_node->name();
          if (std::find(border_inputs.begin(), border_inputs.end(), src_name) ==
              border_inputs.end()) {
            border_inputs.emplace_back(src_name);
          }
        } else {
          ++input_count;
        }
      }
      CHECK(input_count == 0 || input_count == node->in_edges().size())
          << "Invalid input_count(" << input_count << ", "
          << node->in_edges().size() << ") " << node_name;

      for (const Edge* out_edge : node->out_edges()) {
        const Node* dst_node = out_edge->dst();
        CHECK_NOTNULL(dst_node);
        const bool dst_is_outside = node_names.count(dst_node->name()) <= 0;
        const string dst_name =
            strings::StrCat(node->name(), ":", out_edge->src_output());
        if (dst_is_outside) {
          if (dst_node->IsSink()) {
            CHECK_EQ(1, node->num_outputs())
                << "If you want to specify an output node as a subgraph "
                << "output node, its output count must be 1 "
                << "because the node is replaced by an identity node.";
            const string identity_dst_name =
                strings::StrCat(node->name(), ":", 0);
            if (std::find(border_outputs.begin(), border_outputs.end(),
                          identity_dst_name) == border_outputs.end()) {
              border_outputs.emplace_back(identity_dst_name);
            }
          } else {
            if (std::find(border_outputs.begin(), border_outputs.end(),
                          dst_name) == border_outputs.end()) {
              border_outputs.emplace_back(dst_name);
            }
          }
        }
      }
    }
    cluster_infos->emplace_back(ci);
    VLOG(1) << DumpCluster(ci);
  }
  return Status::OK();
}
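
// Illustrative example (hypothetical graph, not part of this file): for a
// chain A -> B -> C -> D with node_names = {"B", "C"}, ClusterizeNodes()
// produces a single ClusterInfo whose nodes are {"B", "C"}, whose input
// border is {"A:0"} and whose output border is {"C:0"}.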

/* static */ Status RemoteFusedGraphExecuteUtils::BuildClusterSubgraphDef(
    const ClusterInfo& cluster, const GraphDef& graph_def,
    GraphDef* subgraph_def) {
  const std::unordered_set<string>& node_names = std::get<0>(cluster);
  const std::unordered_set<string>& border_input_names =
      BuildNodeSetFromNodeNamesAndPorts(std::get<1>(cluster));

  Graph graph(OpRegistry::Global());
  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner));

  for (Node* node : graph.nodes()) {
    if (node != nullptr && node_names.count(node->name()) <= 0 &&
        border_input_names.count(node->name()) <= 0 && !node->IsSource() &&
        !node->IsSink()) {
      graph.RemoveNode(node);
    }
  }
  graph.ToGraphDef(subgraph_def);

  for (const string& subgraph_input : std::get<1>(cluster)) {
    const TensorId tid = ParseTensorName(subgraph_input);
    const string subgraph_input_name = tid.first.ToString();
    const int subgraph_input_port = tid.second;
    const NodeDef* node_def = FindNodeDefByName(subgraph_input_name, graph_def);
    CHECK_NOTNULL(node_def);
    std::vector<DataType> dt_vec;
    std::vector<TensorShape> shape_vec;
    GetOutputTensorShapeType(*node_def, &dt_vec, &shape_vec).IgnoreError();
    const DataType& dt =
        dt_vec.empty() ? DT_FLOAT : dt_vec.at(subgraph_input_port);
    const TensorShape& shape =
        shape_vec.empty() ? TensorShape({}) : shape_vec.at(subgraph_input_port);

    TF_RETURN_IF_ERROR(ReplaceInputNodeByPlaceHolder(subgraph_input_name, dt,
                                                     shape, subgraph_def));
  }

  // Sort subgraph_def to match the node order in graph_def.
  std::unordered_map<string, int> name_to_id_map;
  for (int i = 0; i < graph_def.node_size(); ++i) {
    name_to_id_map.emplace(graph_def.node(i).name(), i);
  }
  std::sort(subgraph_def->mutable_node()->begin(),
            subgraph_def->mutable_node()->end(),
            [&name_to_id_map](const NodeDef& node0, const NodeDef& node1) {
              CHECK(name_to_id_map.count(node0.name()) > 0);
              CHECK(name_to_id_map.count(node1.name()) > 0);
              const int id0 = name_to_id_map.at(node0.name());
              const int id1 = name_to_id_map.at(node1.name());
              return id0 < id1;
            });

  VLOG(1) << DumpGraphDef(*subgraph_def);
  return Status::OK();
}

/* static */ Status RemoteFusedGraphExecuteUtils::BuildClusterByBorder(
    const std::vector<string>& border_inputs,
    const std::vector<string>& border_outputs, const GraphDef& graph_def,
    ClusterInfo* cluster) {
  Graph graph(OpRegistry::Global());
  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner));

  std::unordered_set<const Node*> visited;
  std::deque<const Node*> queue;
  for (const string& output : border_outputs) {
    const TensorId tid = ParseTensorName(output);
    const string& output_node_name = tid.first.ToString();
    for (const Node* node : graph.nodes()) {
      if (output_node_name == node->name()) {
        queue.push_back(node);
        visited.insert(node);
      }
    }
  }

  std::unordered_set<const Node*> border_input_nodes;
  // Propagate the visit to parent nodes until border input nodes are
  // reached.
  while (!queue.empty()) {
    const Node* node = queue.front();
    queue.pop_front();
    for (const Edge* edge : node->in_edges()) {
      const Node* src_node = edge->src();
      CHECK_NOTNULL(src_node);
      const int src_port = edge->src_output();
      bool input_found = false;
      for (const string& input : border_inputs) {
        const TensorId tid = ParseTensorName(input);
        if (tid.first.ToString() == src_node->name() &&
            tid.second == src_port) {
          input_found = true;
          border_input_nodes.insert(src_node);
        }
      }
      if (visited.insert(src_node).second) {
        if (!input_found) {
          queue.push_back(src_node);
        }
      }
    }
  }

  for (const Node* node : visited) {
    if (node != nullptr && !node->IsSource() && !node->IsSink() &&
        border_input_nodes.count(node) <= 0) {
      std::get<0>(*cluster).insert(node->name());
    }
  }
  std::get<1>(*cluster) = border_inputs;
  std::get<2>(*cluster) = border_outputs;
  return Status::OK();
}
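
// Illustrative example (hypothetical graph, not part of this file): for a
// chain A -> B -> C -> D with border_inputs = {"A:0"} and border_outputs =
// {"C:0"}, the walk starts at C and stops at the border input A, so the
// cluster nodes are {"B", "C"} and D is never visited.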

/* static */ Status RemoteFusedGraphExecuteUtils::FuseCluster(
    const GraphDef& input_graph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs,
    const string& remote_fused_graph_node_name, const ClusterInfo& cluster,
    const string& remote_graph_executor_name, const bool require_shape_type,
    GraphDef* output_graph_def) {
  LOG(INFO) << "Transforming quantized stripped model to a remote fused "
               "graph execute op by fusing a specified subgraph...";

  CHECK(!remote_graph_executor_name.empty());

  const std::vector<string>& border_inputs = std::get<1>(cluster);
  const std::vector<string>& border_outputs = std::get<2>(cluster);

  GraphDef subgraph_def;
  TF_RETURN_IF_ERROR(
      BuildClusterSubgraphDef(cluster, input_graph_def, &subgraph_def));

  Graph graph(OpRegistry::Global());
  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
  TF_RETURN_IF_ERROR(
      ImportGraphDef({}, input_graph_def, &graph, &shape_refiner));

  Node* fused_node;
  TF_RETURN_IF_ERROR(BuildRemoteFusedGraphExecuteOpNode(
      remote_fused_graph_node_name, remote_graph_executor_name, subgraph_def,
      border_inputs, border_outputs, require_shape_type, &graph, &fused_node));

  for (const Node* node : graph.nodes()) {
    for (int i = 0; i < node->num_inputs(); ++i) {
      const Edge* edge = nullptr;
      TF_RETURN_IF_ERROR(node->input_edge(i, &edge));
      for (int j = 0; j < border_outputs.size(); ++j) {
        const string& output = border_outputs.at(j);
        const TensorId tid = ParseTensorName(output);
        const string output_name = tid.first.ToString();
        Node* src_node = edge->src();
        if (src_node != nullptr && src_node->name() == output_name &&
            edge->src_output() == tid.second) {
          // The source node is replaced by the new fused node.
          Node* dst_node = edge->dst();
          const int dst_input = edge->dst_input();
          LOG(INFO) << "Removing existing edge to " << edge->dst()->name()
                    << " from " << edge->src()->name();
          graph.RemoveEdge(edge);
          graph.AddEdge(fused_node, j, dst_node, dst_input);
        }
      }
    }
  }

  // Replace output nodes by identity nodes that forward outputs from the
  // RemoteFusedGraphExecute op node.
  for (const string& output : outputs) {
    const TensorId output_tid = ParseTensorName(output);
    const string output_name = output_tid.first.ToString();
    for (size_t i = 0; i < border_outputs.size(); ++i) {
      const TensorId subgraph_output_tid =
          ParseTensorName(border_outputs.at(i));
      const string& subgraph_output_name = subgraph_output_tid.first.ToString();
      if (output_name == subgraph_output_name) {
        LOG(INFO) << "As the graph output and the subgraph output are the "
                  << "same, the graph output node is replaced by an identity "
                  << "node";
        Node* original_output_node = FindMutableNodeByName(output_name, &graph);
        CHECK_NOTNULL(original_output_node);
        CHECK_EQ(1, original_output_node->num_outputs())
            << "Num outputs should be 1 for " << output << ".";
        graph.RemoveNode(original_output_node);
        Node* new_node;
        TF_RETURN_IF_ERROR(BuildIdentityOpNode(output_name,
                                               remote_fused_graph_node_name, i,
                                               DT_FLOAT, &graph, &new_node));
        CHECK_NOTNULL(new_node);
      }
    }
  }

  GraphDef result_graph_def;

  graph.ToGraphDef(&result_graph_def);

  ClusterInfo graph_cluster;
  TF_RETURN_IF_ERROR(
      BuildClusterByBorder(inputs, outputs, result_graph_def, &graph_cluster));

  // Remove unvisited nodes.
  TF_RETURN_IF_ERROR(BuildClusterSubgraphDef(graph_cluster, result_graph_def,
                                             output_graph_def));

  return Status::OK();
}

/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByNodeNames(
    const GraphDef& input_graph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs,
    const string& remote_fused_graph_node_name_prefix,
    const std::unordered_set<string>& subgraph_nodes,
    const string& remote_fused_graph_executor_name,
    const bool require_shape_type, GraphDef* output_graph_def) {
  std::vector<ClusterInfo> ci_vec;
  TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::ClusterizeNodes(
      subgraph_nodes, input_graph_def, &ci_vec));

  for (size_t i = 0; i < ci_vec.size(); ++i) {
    const string remote_fused_graph_node_name =
        strings::StrCat(remote_fused_graph_node_name_prefix, "/", i);
    TF_RETURN_IF_ERROR(FuseCluster(input_graph_def, inputs, outputs,
                                   remote_fused_graph_node_name, ci_vec.at(i),
                                   remote_fused_graph_executor_name,
                                   require_shape_type, output_graph_def));
  }
  return Status::OK();
}

/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByBorder(
    const GraphDef& input_graph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs,
    const string& remote_fused_graph_node_name,
    const std::vector<string>& border_inputs,
    const std::vector<string>& border_outputs,
    const string& remote_graph_executor_name, const bool require_shape_type,
    GraphDef* output_graph_def) {
  ClusterInfo cluster;
  TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::BuildClusterByBorder(
      border_inputs, border_outputs, input_graph_def, &cluster));

  return FuseCluster(
      input_graph_def, inputs, outputs, remote_fused_graph_node_name, cluster,
      remote_graph_executor_name, require_shape_type, output_graph_def);
}

/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByOpTypes(
    const GraphDef& input_graph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs,
    const string& remote_fused_graph_node_name_prefix,
    const std::unordered_set<string>& fused_op_types,
    const string& remote_fused_graph_executor_name,
    const bool require_shape_type, GraphDef* output_graph_def) {
  const std::unordered_set<string> fused_nodes_filtered_by_op_types =
      BuildNodeMapFromOpTypes(input_graph_def, fused_op_types);

  return FuseRemoteGraphByNodeNames(
      input_graph_def, inputs, outputs, remote_fused_graph_node_name_prefix,
      fused_nodes_filtered_by_op_types, remote_fused_graph_executor_name,
      require_shape_type, output_graph_def);
}

/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor(
    const GraphDef& input_graph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs, const string& executor_name,
    GraphDef* output_graph_def) {
  const ExecutorBuildFunc* build_func = GetExecutorBuildFunc(executor_name);
  if (build_func == nullptr) {
    return errors::InvalidArgument("Unknown executor name: " + executor_name);
  }
  std::unique_ptr<IRemoteFusedGraphExecutor> executor;
  TF_RETURN_IF_ERROR((*build_func)(&executor));
  CHECK_NOTNULL(executor.get());
  if (!executor->IsEnabled()) {
    // As this executor is not enabled, just return the original graph as-is.
    *output_graph_def = input_graph_def;
    return Status::OK();
  }
  return executor->FuseRemoteGraph(input_graph_def, inputs, outputs,
                                   output_graph_def);
}

/* static */ Status RemoteFusedGraphExecuteUtils::PlaceRemoteGraphArguments(
    const std::vector<string>& inputs, const std::vector<string>& outputs,
    const std::unordered_set<string>& fused_node_names,
    const std::vector<string>& border_inputs,
    const std::vector<string>& border_outputs,
    const std::unordered_set<string>& fused_op_types,
    const string& remote_fused_graph_node_name,
    const string& remote_graph_executor_name, GraphDef* graph_def) {
  CHECK_NOTNULL(graph_def);

  const std::unordered_set<string> fused_nodes_filtered_by_op_types =
      BuildNodeMapFromOpTypes(*graph_def, fused_op_types);

  for (NodeDef& node_def : *graph_def->mutable_node()) {
    string attr_str;
    TensorId tid;
    for (size_t i = 0; i < inputs.size(); ++i) {
      if (IsSameNodeName(node_def, inputs.at(i), &tid)) {
        AppendDeliminator(&attr_str);
        attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::GRAPH_INPUT,
                                      tid.second, i, remote_graph_executor_name,
                                      remote_fused_graph_node_name);
      }
    }
    for (size_t i = 0; i < outputs.size(); ++i) {
      if (IsSameNodeName(node_def, outputs.at(i), &tid)) {
        AppendDeliminator(&attr_str);
        attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::GRAPH_OUTPUT,
                                      tid.second, i);
      }
    }
    for (const string& fused_node_name : fused_node_names) {
      if (fused_node_name == node_def.name()) {
        AppendDeliminator(&attr_str);
        attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::FUSED_NODE);
      }
    }
    for (const string& fused_node_name : fused_nodes_filtered_by_op_types) {
      if (fused_node_name == node_def.name()) {
        AppendDeliminator(&attr_str);
        attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::FUSED_NODE);
      }
    }
    for (size_t i = 0; i < border_inputs.size(); ++i) {
      if (IsSameNodeName(node_def, border_inputs.at(i), &tid)) {
        AppendDeliminator(&attr_str);
        attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::BORDER_INPUT,
                                      tid.second, i);
      }
    }
    for (size_t i = 0; i < border_outputs.size(); ++i) {
      if (IsSameNodeName(node_def, border_outputs.at(i), &tid)) {
        AppendDeliminator(&attr_str);
        attr_str += BuildNodeTypeAttr(
            RemoteFusedGraphExecuteInfo::BORDER_OUTPUT, tid.second, i);
      }
    }
    if (attr_str.empty()) {
      attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::UNUSED);
    }
    AddNodeAttr(ATTR_NODE_TYPE, attr_str, &node_def);
  }
  return Status::OK();
}
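
// The ATTR_NODE_TYPE attribute written above is a ":"-separated list of
// ","-separated tuples, one tuple per role the node plays:
//   GRAPH_INPUT:  "<node_type>,<port>,<index>,<executor_name>,<node_name>"
//   GRAPH_OUTPUT, BORDER_INPUT, BORDER_OUTPUT: "<node_type>,<port>,<index>"
//   FUSED_NODE, UNUSED: "<node_type>"
// FuseRemoteGraphByPlacedArguments() below parses this format back.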
   1173 
   1174 /* static */ Status
   1175 RemoteFusedGraphExecuteUtils::FuseRemoteGraphByPlacedArguments(
   1176     const GraphDef& input_graph_def,
   1177     const std::vector<std::pair<string, Tensor>>& input_tensors,
   1178     GraphDef* output_graph_def) {
   1179   std::unordered_map<int, string> input_map;
   1180   std::unordered_map<int, string> output_map;
   1181   std::unordered_set<string> fused_node_names;
   1182   std::unordered_map<int, string> border_input_map;
   1183   std::unordered_map<int, string> border_output_map;
   1184   string remote_graph_executor_name;
   1185   string remote_fused_graph_node_name;
   1186 
  for (const NodeDef& node_def : input_graph_def.node()) {
    string attr_str;
    TF_RETURN_IF_ERROR(GetNodeAttr(node_def, ATTR_NODE_TYPE, &attr_str));
    std::vector<std::vector<string>> attr_strs;
    for (const string& str : str_util::Split(attr_str, ":")) {
      attr_strs.emplace_back(str_util::Split(str, ","));
    }
    if (attr_strs.empty()) {
      return errors::InvalidArgument("Remote graph node type not found.");
    }
    for (const std::vector<string>& attr : attr_strs) {
      if (attr.empty()) {
        return errors::InvalidArgument("Empty remote graph node type attr.");
      }
      int node_type_int;
      CHECK(strings::safe_strto32(attr.at(0), &node_type_int)) << attr.at(0);
      const RemoteFusedGraphExecuteInfo::NodeType node_type =
          static_cast<RemoteFusedGraphExecuteInfo::NodeType>(node_type_int);
      const string& name = node_def.name();
      int port;
      int index;

      switch (node_type) {
        case RemoteFusedGraphExecuteInfo::GRAPH_INPUT:
          VLOG(2) << "Graph input: " << name;
          CHECK_EQ(5, attr.size());
          CHECK(strings::safe_strto32(attr.at(1), &port));
          CHECK(strings::safe_strto32(attr.at(2), &index));
          CHECK(!attr.at(3).empty());
          remote_graph_executor_name = attr.at(3);
          CHECK(!attr.at(4).empty());
          remote_fused_graph_node_name = attr.at(4);
          input_map.emplace(index, strings::StrCat(name, ":", port));
          if (GetExecutorBuildFunc(remote_graph_executor_name) == nullptr) {
            LOG(INFO) << "Executor for " << remote_graph_executor_name
                      << " not registered.  Do not fuse.";
            *output_graph_def = input_graph_def;
            return Status::OK();
          }
          break;
        case RemoteFusedGraphExecuteInfo::GRAPH_OUTPUT:
          VLOG(2) << "Graph output: " << name;
          CHECK_EQ(3, attr.size());
          CHECK(strings::safe_strto32(attr.at(1), &port));
          CHECK(strings::safe_strto32(attr.at(2), &index));
          output_map.emplace(index, strings::StrCat(name, ":", port));
          break;
        case RemoteFusedGraphExecuteInfo::FUSED_NODE:
          VLOG(2) << "Fused node: " << name;
          CHECK_EQ(1, attr.size());
          fused_node_names.emplace(name);
          break;
        case RemoteFusedGraphExecuteInfo::BORDER_INPUT:
          VLOG(2) << "Border input: " << name;
          CHECK_EQ(3, attr.size());
          CHECK(strings::safe_strto32(attr.at(1), &port));
          CHECK(strings::safe_strto32(attr.at(2), &index));
          border_input_map.emplace(index, strings::StrCat(name, ":", port));
          break;
        case RemoteFusedGraphExecuteInfo::BORDER_OUTPUT:
          VLOG(2) << "Border output: " << name;
          CHECK_EQ(3, attr.size());
          CHECK(strings::safe_strto32(attr.at(1), &port));
          CHECK(strings::safe_strto32(attr.at(2), &index));
          border_output_map.emplace(index, strings::StrCat(name, ":", port));
          break;
        case RemoteFusedGraphExecuteInfo::UNUSED:
          // do nothing
          break;
        default:
          // Unsupported node type value; crash with a diagnostic.
          LOG(FATAL) << "Unsupported node type: " << node_type_int;
      }
    }
  }
  bool require_shape_type = false;
  std::vector<string> inputs;
  std::vector<string> outputs;
  std::vector<string> border_inputs;
  std::vector<string> border_outputs;
  ConvertMapToVector(input_map, &inputs);
  ConvertMapToVector(output_map, &outputs);
  ConvertMapToVector(border_input_map, &border_inputs);
  ConvertMapToVector(border_output_map, &border_outputs);

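  // If input tensors are supplied, fuse only when they correspond to the
  // recorded graph inputs.  When a recorded input also carries shape and
  // type, the supplied tensor must match both, and the fused graph will
  // then require shape-type information.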
  if (!input_tensors.empty()) {
    bool input_match = false;
    if (inputs.size() == input_tensors.size()) {
      for (const std::pair<string, Tensor>& input_tensor : input_tensors) {
        if (!ContainsSameTensorId(input_tensor.first, inputs)) {
          break;
        }
        DataType data_type;
        TensorShape shape;
        if (GetOutputTensorShapeType(input_graph_def, input_tensor.first,
                                     &data_type, &shape)) {
          if (data_type == input_tensor.second.dtype() &&
              shape == input_tensor.second.shape()) {
            VLOG(2) << "Input matched!";
            // Shape type matched.
            input_match = true;
            require_shape_type = true;
          }
        } else {
          // Shape type not required.
          input_match = true;
        }
      }
    }
    if (!input_match) {
      // Input mismatch.  Just copy the original graph.
      *output_graph_def = input_graph_def;
      return Status::OK();
    }
  }

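  // Explicit fused node names take precedence; otherwise fuse by border
  // inputs/outputs; with neither annotation, pass the graph through
  // unchanged.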
  if (!fused_node_names.empty()) {
    TF_RETURN_IF_ERROR(FuseRemoteGraphByNodeNames(
        input_graph_def, inputs, outputs, remote_fused_graph_node_name,
        fused_node_names, remote_graph_executor_name, require_shape_type,
        output_graph_def));
  } else if (!border_inputs.empty() || !border_outputs.empty()) {
    TF_RETURN_IF_ERROR(FuseRemoteGraphByBorder(
        input_graph_def, inputs, outputs, remote_fused_graph_node_name,
        border_inputs, border_outputs, remote_graph_executor_name,
        require_shape_type, output_graph_def));
  } else {
    *output_graph_def = input_graph_def;
  }

  return Status::OK();
}

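// Returns true if every given input tensor names a node that exists in the
// graph and already carries a non-empty ATTR_NODE_TYPE annotation, i.e. the
// graph has been through the placement pass and can be fused.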
/* static */ bool RemoteFusedGraphExecuteUtils::IsFuseReady(
    const GraphDef& graph_def,
    const std::vector<std::pair<string, Tensor>>& input_tensors) {
  for (const std::pair<string, Tensor>& input_tensor : input_tensors) {
    const NodeDef* node_def = FindNodeDefByName(input_tensor.first, graph_def);
    if (node_def == nullptr) {
      return false;
    }
    string attr;
    const Status status = GetNodeAttr(*node_def, ATTR_NODE_TYPE, &attr);
    if (!status.ok() || attr.empty()) {
      return false;
    }
  }
  return true;
}

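// Copies a raw byte buffer into the tensor's flat backing storage.  The
// tensor must already be allocated with at least src_size bytes; the dtype
// switch only selects the matching destination pointer type.
//
// A minimal usage sketch (buffer and shape are hypothetical):
//
//   Tensor t(DT_FLOAT, TensorShape({2, 2}));
//   const float data[4] = {0.0f, 1.0f, 2.0f, 3.0f};
//   TF_CHECK_OK(RemoteFusedGraphExecuteUtils::CopyByteArrayToTensor(
//       data, sizeof(data), &t));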
/* static */ Status RemoteFusedGraphExecuteUtils::CopyByteArrayToTensor(
    const void* src_ptr, const int src_size, Tensor* tensor) {
  CHECK(tensor->TotalBytes() >= src_size)
      << tensor->TotalBytes() << ", " << src_size;
  void* dst_ptr;
  switch (tensor->dtype()) {
    case DT_FLOAT:
      dst_ptr = tensor->flat<float>().data();
      break;
    case DT_DOUBLE:
      dst_ptr = tensor->flat<double>().data();
      break;
    case DT_INT32:
      dst_ptr = tensor->flat<int32>().data();
      break;
    case DT_UINT8:
      dst_ptr = tensor->flat<uint8>().data();
      break;
    case DT_INT16:
      dst_ptr = tensor->flat<int16>().data();
      break;
    case DT_INT8:
      dst_ptr = tensor->flat<int8>().data();
      break;
    case DT_STRING:
      dst_ptr = tensor->flat<string>().data();
      break;
    case DT_INT64:
      dst_ptr = tensor->flat<int64>().data();
      break;
    case DT_BOOL:
      dst_ptr = tensor->flat<bool>().data();
      break;
    case DT_QINT8:
      dst_ptr = tensor->flat<qint8>().data();
      break;
    case DT_QUINT8:
      dst_ptr = tensor->flat<quint8>().data();
      break;
    case DT_QINT32:
      dst_ptr = tensor->flat<qint32>().data();
      break;
    case DT_BFLOAT16:
      dst_ptr = tensor->flat<bfloat16>().data();
      break;
    case DT_QINT16:
      dst_ptr = tensor->flat<qint16>().data();
      break;
    case DT_QUINT16:
      dst_ptr = tensor->flat<quint16>().data();
      break;
    case DT_UINT16:
      dst_ptr = tensor->flat<uint16>().data();
      break;
    default:
      LOG(FATAL) << "type " << tensor->dtype() << " is not supported.";
      break;
  }
  CHECK_NOTNULL(dst_ptr);
  std::memcpy(dst_ptr, src_ptr, src_size);
  return Status::OK();
}

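// Returns the names of all nodes whose op type appears in op_types.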
/* static */ std::unordered_set<string>
RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpTypes(
    const GraphDef& graph_def, const std::unordered_set<string>& op_types) {
  std::unordered_set<string> retval;
  for (const NodeDef& node_def : graph_def.node()) {
    if (op_types.count(node_def.op()) > 0) {
      retval.emplace(node_def.name());
    }
  }
  return retval;
}

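// Returns the names of all nodes supported by the given ops definitions,
// i.e. those for which GetOpIdFor returns a valid op id for the node's op
// type and its inferred output data types.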
/* static */ std::unordered_set<string>
RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions(
    const GraphDef& graph_def,
    const IRemoteFusedGraphOpsDefinitions& ops_definitions) {
  std::unordered_set<string> retval;
  for (const NodeDef& node_def : graph_def.node()) {
    std::vector<DataType> dt_vec;
    std::vector<TensorShape> shape_vec;
    const Status status =
        GetOutputTensorShapeType(node_def, &dt_vec, &shape_vec);
    if (!status.ok()) {
      shape_vec.clear();
    }
    if (ops_definitions.GetOpIdFor(
            node_def.op(), DataTypeVector(dt_vec.begin(), dt_vec.end())) !=
        IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID) {
      retval.emplace(node_def.name());
    }
  }
  return retval;
}

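// Replaces the node named by "input" (which must address output port 0)
// with a Placeholder of the given type and shape, keeping the node name so
// downstream consumers stay wired up.  No-op if the node is already a
// Placeholder; returns InvalidArgument if no node with that name exists.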
/* static */ Status RemoteFusedGraphExecuteUtils::ReplaceInputNodeByPlaceHolder(
    const string& input, const DataType type, const TensorShape& shape,
    GraphDef* graph_def) {
  const TensorId tid = ParseTensorName(input);
  CHECK_EQ(0, tid.second);
  const string node_name = tid.first.ToString();
  for (NodeDef& node : *graph_def->mutable_node()) {
    if (node.name() != node_name) {
      continue;
    }
    if (node.op() == "Placeholder") {
      return Status::OK();
    } else {
      NodeDef placeholder_node;
      placeholder_node.set_op("Placeholder");
      placeholder_node.set_name(node_name);
      AddNodeAttr("dtype", type, &placeholder_node);
      AddNodeAttr("shape", shape, &placeholder_node);
      // TODO(satok): Remove once we merge attributes
      AddOutputTensorShapeType({type}, {shape}, &placeholder_node);
      node.Clear();
      node = placeholder_node;
      return Status::OK();
    }
  }
  return errors::InvalidArgument(
      strings::StrCat(node_name, " not found for replacement."));
}

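// The BuildNodeTypeAttr overloads serialize a node-type annotation as the
// comma-separated tuple decoded by FuseRemoteGraphByPlacedArguments above:
// the five-element form for graph inputs, the three-element form for
// outputs and border nodes, and the bare enum value otherwise.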
/* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr(
    const RemoteFusedGraphExecuteInfo::NodeType node_type, const int port,
    const int index, const string& executor_name, const string& node_name) {
  return strings::StrCat(static_cast<int>(node_type), ",", port, ",", index,
                         ",", executor_name, ",", node_name);
}

/* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr(
    const RemoteFusedGraphExecuteInfo::NodeType node_type, const int port,
    const int index) {
  return strings::StrCat(static_cast<int>(node_type), ",", port, ",", index);
}

/* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr(
    const RemoteFusedGraphExecuteInfo::NodeType node_type) {
  return strings::StrCat(static_cast<int>(node_type));
}

}  // namespace tensorflow