1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h" 17 18 #include <algorithm> 19 #include <queue> 20 #include <utility> 21 22 #include "tensorflow/core/common_runtime/shape_refiner.h" 23 #include "tensorflow/core/framework/node_def_util.h" 24 #include "tensorflow/core/framework/tensor.pb.h" 25 #include "tensorflow/core/framework/tensor_shape.pb.h" 26 #include "tensorflow/core/graph/algorithm.h" 27 #include "tensorflow/core/graph/node_builder.h" 28 #include "tensorflow/core/public/session.h" 29 #include "tensorflow/core/public/session_options.h" 30 31 namespace tensorflow { 32 namespace { 33 const Node* FindNodeByName(const string& name, const Graph& graph) { 34 for (const Node* node : graph.nodes()) { 35 CHECK_NOTNULL(node); 36 if (node->name() == name) { 37 return node; 38 } 39 } 40 return nullptr; 41 } 42 43 std::unordered_set<string> BuildNodeSetFromNodeNamesAndPorts( 44 const std::vector<string>& node_names_and_ports) { 45 std::unordered_set<string> retval; 46 for (const string& node_name_and_port : node_names_and_ports) { 47 const TensorId tid = ParseTensorName(node_name_and_port); 48 retval.emplace(tid.first.ToString()); 49 } 50 return retval; 51 } 52 53 Node* FindMutableNodeByName(const string& name, Graph* graph) { 54 for (Node* node : graph->nodes()) { 55 if (node != nullptr && node->name() == name) { 56 return node; 57 } 58 } 59 return nullptr; 60 } 61 62 const NodeDef* FindNodeDefByName(const string& input, 63 const GraphDef& graph_def) { 64 const TensorId tid = ParseTensorName(input); 65 const string name = tid.first.ToString(); 66 for (const NodeDef& node_def : graph_def.node()) { 67 if (node_def.name() == name) { 68 return &node_def; 69 } 70 } 71 return nullptr; 72 } 73 74 bool IsSameNodeName(const NodeDef& node_def, const string& node_name_and_port, 75 TensorId* tid) { 76 CHECK_NOTNULL(tid); 77 *tid = ParseTensorName(node_name_and_port); 78 if (node_def.name() == tid->first.ToString()) { 79 return true; 80 } 81 return false; 82 } 83 84 bool ContainsSameTensorId(const string& tensor_name, 85 const std::vector<string>& tensor_names) { 86 const TensorId tid0 = ParseTensorName(tensor_name); 87 for (const string& name : tensor_names) { 88 const TensorId tid1 = ParseTensorName(name); 89 if (tid0.first == tid1.first && tid0.second == tid1.second) { 90 return true; 91 } 92 } 93 return false; 94 } 95 96 void AppendDeliminator(string* str) { 97 CHECK_NOTNULL(str); 98 if (!str->empty()) { 99 *str += ":"; 100 } 101 } 102 103 void ConvertMapToVector(const std::unordered_map<int, string>& in, 104 std::vector<string>* out) { 105 CHECK_NOTNULL(out); 106 out->resize(in.size()); 107 for (size_t i = 0; i < in.size(); ++i) { 108 CHECK(in.count(i) > 0); 109 out->at(i) = in.at(i); 110 } 111 } 112 113 string DumpGraphDef(const GraphDef& graph_def) { 114 string out; 115 for (const NodeDef& node : graph_def.node()) { 116 out += strings::StrCat("node: ", node.name(), "\n input: "); 117 for (const string& input : node.input()) { 118 out += strings::StrCat(input, ", "); 119 } 120 out += "\n"; 121 } 122 return out; 123 } 124 125 string DumpCluster(const RemoteFusedGraphExecuteUtils::ClusterInfo& cluster) { 126 string out; 127 out += "Nodes:\n"; 128 for (const string& str : std::get<0>(cluster)) { 129 out += str + ", "; 130 } 131 out += "\nInput border:\n"; 132 for (const string& str : std::get<1>(cluster)) { 133 out += str + ", "; 134 } 135 out += "\nOutput border:\n"; 136 for (const string& str : std::get<2>(cluster)) { 137 out += str + ", "; 138 } 139 return out; 140 } 141 142 } // namespace 143 144 /* static */ constexpr const char* const 145 RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES; 146 /* static */ constexpr const char* const 147 RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES; 148 /* static */ constexpr const char* const RemoteFusedGraphExecuteUtils:: 149 ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO; 150 /* static */ constexpr const char* const 151 RemoteFusedGraphExecuteUtils::ATTR_NODE_TYPE; 152 /* static */ constexpr const char* const RemoteFusedGraphExecuteUtils:: 153 TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME; 154 /* static */ constexpr const char* const 155 RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME; 156 /* static */ constexpr const char* const 157 RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_NODES; 158 /* static */ constexpr const char* const 159 RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_INPUTS; 160 /* static */ constexpr const char* const 161 RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS; 162 /* static */ constexpr const char* const 163 RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES; 164 /* static */ constexpr const char* const 165 RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR; 166 /* static */ constexpr const char* const 167 RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES; 168 /* static */ constexpr const char* const 169 RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_SHAPES; 170 171 RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar::ExecutorBuildRegistrar( 172 const string& name, ExecutorBuildFunc executor_build_func) { 173 ExecutorBuildRegistry& executor_build_registry = *GetExecutorBuildRegistry(); 174 executor_build_registry[name] = std::move(executor_build_func); 175 } 176 177 /* static */ const RemoteFusedGraphExecuteUtils::ExecutorBuildFunc* 178 RemoteFusedGraphExecuteUtils::GetExecutorBuildFunc(const string& name) { 179 ExecutorBuildRegistry& executor_build_registry = *GetExecutorBuildRegistry(); 180 if (executor_build_registry.count(name) <= 0) { 181 return nullptr; 182 } 183 return &executor_build_registry.at(name); 184 } 185 186 /* static */ RemoteFusedGraphExecuteUtils::ExecutorBuildRegistry* 187 RemoteFusedGraphExecuteUtils::GetExecutorBuildRegistry() { 188 static ExecutorBuildRegistry executor_builder_registry; 189 return &executor_builder_registry; 190 } 191 192 /** 193 * - DryRunInference 194 * To determine shapes of output tensors of all nodes, dryrun the graph. 195 * This function supplies memory allocation information when loading 196 * the graph. This function is used to verify shape inference and actual 197 * output shape. 198 */ 199 /* static */ Status RemoteFusedGraphExecuteUtils::DryRunInference( 200 const GraphDef& graph_def, 201 const std::vector<std::pair<string, Tensor>>& input_node_info_list, 202 const std::vector<string>& output_node_names, const bool initialize_by_zero, 203 std::vector<tensorflow::Tensor>* output_tensors) { 204 // Create input tensor vector. If "initialize_by_zero" is true, 205 // input tensor fields are initialized by 0. 206 std::vector<std::pair<string, tensorflow::Tensor>> input_tensors; 207 for (const std::pair<string, Tensor>& input : input_node_info_list) { 208 CHECK(input.second.IsInitialized()); 209 if (!initialize_by_zero) { 210 input_tensors.push_back({input.first, input.second}); 211 continue; 212 } 213 // If input tensor is not initialized, initialize by 0-filling 214 const DataType data_type = input.second.dtype(); 215 const TensorShape& shape = input.second.shape(); 216 Tensor input_tensor(data_type, shape); 217 switch (data_type) { 218 case DT_INT32: { 219 auto int_tensor = input_tensor.flat<int32>(); 220 int_tensor = int_tensor.constant(0); 221 break; 222 } 223 case DT_FLOAT: { 224 auto float_tensor = input_tensor.flat<float>(); 225 float_tensor = float_tensor.constant(0.0f); 226 break; 227 } 228 case DT_QUINT8: { 229 auto int_tensor = input_tensor.flat<quint8>(); 230 int_tensor = int_tensor.constant(0); 231 break; 232 } 233 default: 234 LOG(FATAL) << "Unsupported input type: " << data_type; 235 } 236 input_tensors.push_back({input.first, input_tensor}); 237 } 238 239 // Setup session 240 CHECK(output_tensors != nullptr); 241 SessionOptions session_options; 242 session_options.env = Env::Default(); 243 std::unique_ptr<Session> session = 244 std::unique_ptr<Session>(NewSession(session_options)); 245 Status status = session->Create(graph_def); 246 if (!status.ok()) { 247 return status; 248 } 249 250 // Setup session arguments 251 RunOptions run_options; 252 run_options.set_trace_level(RunOptions::FULL_TRACE); 253 RunMetadata run_metadata; 254 255 // Run inference with all node as output 256 status = session->Run(run_options, input_tensors, output_node_names, {}, 257 output_tensors, &run_metadata); 258 if (!status.ok()) { 259 LOG(ERROR) << "Error during inference: " << status; 260 return status; 261 } 262 return Status(); 263 } 264 265 /* static */ Status RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode( 266 const GraphDef& graph_def, 267 const std::vector<std::pair<string, Tensor>>& input_node_info_list, 268 const bool initialize_by_zero, 269 RemoteFusedGraphExecuteUtils::TensorShapeMap* tensor_shape_map) { 270 CHECK(tensor_shape_map != nullptr); 271 std::vector<Tensor> output_tensors; 272 output_tensors.reserve(graph_def.node_size()); 273 std::vector<string> output_node_names; 274 275 Graph graph(OpRegistry::Global()); 276 Status status = ImportGraphDef({}, graph_def, &graph, nullptr); 277 if (!status.ok()) { 278 return status; 279 } 280 281 for (const Node* node : graph.nodes()) { 282 if (IsInputNode(input_node_info_list, node->name())) { 283 continue; 284 } 285 for (int i = 0; i < node->num_outputs(); ++i) { 286 output_node_names.emplace_back(strings::StrCat(node->name(), ":", i)); 287 } 288 } 289 290 status = DryRunInference(graph_def, input_node_info_list, output_node_names, 291 initialize_by_zero, &output_tensors); 292 if (!status.ok()) { 293 VLOG(1) << "Failed to dryrun " << status; 294 return status; 295 } 296 297 CHECK_EQ(output_node_names.size(), output_tensors.size()) 298 << output_node_names.size() << ", " << output_tensors.size(); 299 300 // Append output tensor of input node in advance to create a map 301 // to avoid memory reallocation inside vector 302 for (const std::pair<string, Tensor>& input_node_info : 303 input_node_info_list) { 304 output_tensors.push_back(input_node_info.second); 305 } 306 307 for (int i = 0; static_cast<size_t>(i) < output_node_names.size(); ++i) { 308 const string& name = output_node_names.at(i); 309 const Tensor& tensor = output_tensors.at(i); 310 EmplaceTensorShapeType(name, tensor, tensor_shape_map); 311 } 312 for (int i = 0; static_cast<size_t>(i) < input_node_info_list.size(); ++i) { 313 const string& name = input_node_info_list.at(i).first; 314 const Tensor& tensor = output_tensors.at(output_node_names.size() + i); 315 EmplaceTensorShapeType(name, tensor, tensor_shape_map); 316 } 317 CHECK_EQ(output_node_names.size() + input_node_info_list.size(), 318 output_tensors.size()); 319 return status; 320 } 321 322 /* static */ bool RemoteFusedGraphExecuteUtils::IsInputNode( 323 const std::vector<std::pair<string, Tensor>>& input_tensor_vector, 324 const string& node_name) { 325 for (const std::pair<string, Tensor>& pair : input_tensor_vector) { 326 const TensorId tid = ParseTensorName(pair.first); 327 if (node_name == tid.first.ToString()) { 328 return true; 329 } 330 } 331 return false; 332 } 333 334 /* static */ void RemoteFusedGraphExecuteUtils::ConvertToTensorShapeMap( 335 const std::vector<std::pair<string, Tensor>>& input_node_info_list, 336 const std::vector<string>& output_node_names, 337 const std::vector<tensorflow::Tensor>& output_tensors, 338 TensorShapeMap* tensor_shape_map) { 339 CHECK_NE(tensor_shape_map, nullptr); 340 tensor_shape_map->clear(); 341 tensor_shape_map->reserve(input_node_info_list.size() + 342 output_node_names.size()); 343 const int output_node_count = output_node_names.size(); 344 CHECK_EQ(output_node_count, output_tensors.size()); 345 for (int i = 0; i < output_node_count; ++i) { 346 const string& node_name = output_node_names.at(i); 347 const Tensor& tensor = output_tensors.at(i); 348 EmplaceTensorShapeType(node_name, tensor, tensor_shape_map); 349 } 350 } 351 352 /* static */ Status RemoteFusedGraphExecuteUtils::MakeTensorFromProto( 353 const TensorProto& tensor_proto, Tensor* tensor) { 354 if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) { 355 Tensor parsed(tensor_proto.dtype()); 356 if (parsed.FromProto(cpu_allocator(), tensor_proto)) { 357 *tensor = parsed; 358 return Status::OK(); 359 } 360 } 361 return errors::InvalidArgument("Cannot parse tensor from proto"); 362 } 363 364 /* static */ bool RemoteFusedGraphExecuteUtils::AddOutputTensorShapeType( 365 const std::vector<DataType>& data_types, 366 const std::vector<TensorShape>& shapes, NodeDef* node_def) { 367 AddNodeAttr(ATTR_OUTPUT_DATA_TYPES, data_types, node_def); 368 AddNodeAttr(ATTR_OUTPUT_SHAPES, shapes, node_def); 369 return true; 370 } 371 372 /* static */ Status 373 RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap( 374 const TensorShapeMap& tensor_shape_map, NodeDef* node_def) { 375 CHECK_NE(node_def, nullptr); 376 std::priority_queue<std::tuple<int, const TensorShapeType*>> queue; 377 auto its = tensor_shape_map.equal_range(node_def->name()); 378 for (auto it = its.first; it != its.second; ++it) { 379 queue.emplace(std::make_tuple(it->second.first, &it->second.second)); 380 } 381 int last_port = queue.size(); 382 std::vector<DataType> data_types; 383 std::vector<TensorShape> shapes; 384 while (!queue.empty()) { 385 const int port = std::get<0>(queue.top()); 386 const TensorShapeType* tst = std::get<1>(queue.top()); 387 CHECK_NE(tst, nullptr); 388 data_types.emplace(data_types.begin(), tst->first); 389 shapes.emplace(shapes.begin(), tst->second); 390 CHECK_EQ(last_port - 1, port); 391 last_port = port; 392 queue.pop(); 393 } 394 AddOutputTensorShapeType(data_types, shapes, node_def); 395 return Status::OK(); 396 } 397 398 /* static */ Status RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType( 399 AttrSlice attrs, std::vector<DataType>* data_types, 400 std::vector<TensorShape>* shapes) { 401 Status status; 402 if (data_types != nullptr) { 403 status = GetNodeAttr(attrs, ATTR_OUTPUT_DATA_TYPES, data_types); 404 } 405 if (!status.ok()) { 406 return status; 407 } 408 if (shapes != nullptr) { 409 status = GetNodeAttr(attrs, ATTR_OUTPUT_SHAPES, shapes); 410 if (status.ok() && data_types != nullptr) { 411 CHECK_EQ(data_types->size(), shapes->size()); 412 } 413 } 414 415 return status; 416 } 417 418 /* static */ bool RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType( 419 const GraphDef& graph_def, const string& name_and_port, DataType* data_type, 420 TensorShape* shape) { 421 std::vector<DataType> data_types; 422 std::vector<TensorShape> shapes; 423 const TensorId tid = ParseTensorName(name_and_port); 424 const string node_name = tid.first.ToString(); 425 const int port = tid.second; 426 const NodeDef* node_def = FindNodeDefByName(node_name, graph_def); 427 CHECK_NOTNULL(node_def); 428 GetOutputTensorShapeType(*node_def, &data_types, &shapes).IgnoreError(); 429 if (data_types.empty()) { 430 return false; 431 } 432 CHECK(data_types.size() > port); 433 *data_type = data_types.at(port); 434 *shape = shapes.at(port); 435 return true; 436 } 437 438 /* static */ Status RemoteFusedGraphExecuteUtils::PropagateShapeInference( 439 const GraphDef& graph_def, 440 const std::vector<std::pair<string, Tensor>>& input_node_info_list, 441 Graph* graph, ShapeRefiner* shape_refiner) { 442 Status status; 443 auto visit = [&shape_refiner, &input_node_info_list, &status](Node* node) { 444 if (!status.ok()) { 445 return; 446 } 447 CHECK_NE(node, nullptr); 448 // If we visit an input node, we use the shape provided and set the 449 // shape accordingly. 450 bool is_input_node = false; 451 for (const std::pair<string, Tensor>& input_node_info : 452 input_node_info_list) { 453 if (node->name() == input_node_info.first) { 454 shape_inference::InferenceContext* context = 455 shape_refiner->GetContext(node); 456 shape_inference::ShapeHandle handle; 457 status = context->MakeShapeFromTensorShape( 458 input_node_info.second.shape(), &handle); 459 if (!status.ok()) { 460 break; 461 } 462 status = shape_refiner->SetShape(node, 0, handle); 463 if (!status.ok()) { 464 break; 465 } 466 is_input_node = true; 467 } 468 if (!status.ok()) { 469 break; 470 } 471 } 472 // If not an input node call AddNode() that recomputes the shape. 473 if (!is_input_node && status.ok()) { 474 status = shape_refiner->AddNode(node); 475 } 476 if (!status.ok()) { 477 VLOG(1) << "Shape inference failed for node: " << node->name(); 478 } 479 }; 480 481 ReverseDFS(*graph, {}, visit); 482 483 return status; 484 } 485 486 /* static */ Status RemoteFusedGraphExecuteUtils::BuildTensorShapeMapFromGraph( 487 const Graph& graph, const ShapeRefiner& shape_refiner, 488 TensorShapeMap* tensor_shape_map) { 489 for (int i = 0; i < graph.num_node_ids(); ++i) { 490 const Node* node = graph.FindNodeId(i); 491 CHECK_NE(node, nullptr); 492 for (int j = 0; j < node->num_outputs(); ++j) { 493 const int output_index = j; 494 const DataType dt = node->output_type(output_index); 495 shape_inference::InferenceContext* context = 496 shape_refiner.GetContext(node); 497 CHECK_NE(context, nullptr); 498 shape_inference::ShapeHandle shape_handle = context->output(output_index); 499 if (context->RankKnown(shape_handle)) { 500 TensorShape ts; 501 for (int k = 0; k < context->Rank(shape_handle); ++k) { 502 shape_inference::DimensionHandle dh = context->Dim(shape_handle, k); 503 CHECK(context->ValueKnown(dh)); 504 ts.AddDim(context->Value(dh)); 505 } 506 const string& node_name = node->name(); 507 CHECK(tensor_shape_map->count(node_name) == 0); 508 tensor_shape_map->emplace(node_name, 509 std::make_pair(j, std::make_pair(dt, ts))); 510 } else { 511 return errors::InvalidArgument("Graph contains unknow shapes"); 512 } 513 } 514 } 515 return Status::OK(); 516 } 517 518 /* static */ const RemoteFusedGraphExecuteUtils::TensorShapeType* 519 RemoteFusedGraphExecuteUtils::GetTensorShapeType( 520 const TensorShapeMap& tensor_shape_map, const string& node_name) { 521 if (node_name.find(':') != string::npos) { 522 const TensorId tid = ParseTensorName(node_name); 523 return GetTensorShapeType(tensor_shape_map, tid.first.ToString(), 524 tid.second); 525 } else { 526 return GetTensorShapeType(tensor_shape_map, node_name, 0); 527 } 528 } 529 530 /* static */ const RemoteFusedGraphExecuteUtils::TensorShapeType* 531 RemoteFusedGraphExecuteUtils::GetTensorShapeType( 532 const TensorShapeMap& tensor_shape_map, const string& node_name, 533 const int port) { 534 CHECK_EQ(node_name.find(':'), string::npos); 535 if (tensor_shape_map.count(node_name) <= 0) { 536 return nullptr; 537 } 538 auto its = tensor_shape_map.equal_range(node_name); 539 for (auto it = its.first; it != its.second; ++it) { 540 if (it->second.first == port) { 541 return &it->second.second; 542 } 543 } 544 return nullptr; 545 } 546 547 /* static */ void 548 RemoteFusedGraphExecuteUtils::BuildRemoteGraphInputsAndOutputsFromProto( 549 const RemoteFusedGraphExecuteInfo& proto, 550 std::vector<std::pair<string, Tensor>>* inputs, 551 std::vector<string>* outputs) { 552 CHECK_EQ(proto.graph_input_node_name_size(), 553 proto.default_graph_input_tensor_shape_size()); 554 for (int i = 0; i < proto.graph_input_node_name_size(); ++i) { 555 inputs->emplace_back( 556 proto.graph_input_node_name(i), 557 Tensor(proto.default_graph_input_tensor_shape(i).dtype(), 558 TensorShape(proto.default_graph_input_tensor_shape(i).shape()))); 559 } 560 for (const string& output_node_name : proto.graph_output_node_name()) { 561 outputs->emplace_back(output_node_name); 562 } 563 } 564 565 /* static */ void RemoteFusedGraphExecuteUtils::EmplaceTensorShapeType( 566 const string& name, const Tensor& tensor, 567 TensorShapeMap* tensor_shape_map) { 568 const TensorId tid = ParseTensorName(name); 569 CHECK_EQ(tensor_shape_map->count(name), 0); 570 tensor_shape_map->emplace( 571 tid.first.ToString(), 572 std::make_pair(tid.second, 573 std::make_pair(tensor.dtype(), tensor.shape()))); 574 } 575 576 /* static */ Status RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes( 577 const std::vector<std::pair<string, Tensor>>& input_tensors, 578 const bool dry_run_inference, GraphDef* graph_def) { 579 TensorShapeMap tensor_shape_map; 580 if (dry_run_inference) { 581 TF_RETURN_IF_ERROR(DryRunInferenceForAllNode(*graph_def, input_tensors, 582 /*initialize_by_zero=*/true, 583 &tensor_shape_map)); 584 } else { 585 ImportGraphDefOptions opts; 586 Graph graph(OpRegistry::Global()); 587 ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); 588 TF_RETURN_IF_ERROR( 589 ImportGraphDef(opts, *graph_def, &graph, &shape_refiner)); 590 TF_RETURN_IF_ERROR(PropagateShapeInference(*graph_def, input_tensors, 591 &graph, &shape_refiner)); 592 TF_RETURN_IF_ERROR( 593 BuildTensorShapeMapFromGraph(graph, shape_refiner, &tensor_shape_map)); 594 } 595 596 for (NodeDef& node_def : *graph_def->mutable_node()) { 597 TF_RETURN_IF_ERROR( 598 AddOutputTensorShapeTypeByTensorShapeMap(tensor_shape_map, &node_def)); 599 } 600 601 return Status::OK(); 602 } 603 604 /* static */ Status 605 RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteInfo( 606 const string& executor_name, const GraphDef& subgraph_def, 607 const std::vector<string>& inputs, const std::vector<string>& outputs, 608 const bool require_shape_type, RemoteFusedGraphExecuteInfo* execute_info, 609 DataTypeVector* input_types, DataTypeVector* output_types) { 610 CHECK_NOTNULL(execute_info); 611 CHECK_NOTNULL(input_types); 612 CHECK_NOTNULL(output_types); 613 614 execute_info->Clear(); 615 execute_info->set_executor_name(executor_name); 616 617 // copy graph 618 *execute_info->mutable_remote_graph() = subgraph_def; 619 620 for (const string& input : inputs) { 621 DataType dt; 622 TensorShape shape; 623 const bool has_shapetype = 624 GetOutputTensorShapeType(subgraph_def, input, &dt, &shape); 625 626 execute_info->add_graph_input_node_name(input); 627 if (has_shapetype) { 628 RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type = 629 *execute_info->add_default_graph_input_tensor_shape(); 630 tensor_shape_type.set_dtype(dt); 631 TensorShapeProto& tensor_shape_proto = *tensor_shape_type.mutable_shape(); 632 for (const int64 dim : shape.dim_sizes()) { 633 tensor_shape_proto.add_dim()->set_size(dim); 634 } 635 input_types->push_back(dt); 636 } else { 637 CHECK(!require_shape_type) 638 << "No shape type found for " << input << DumpGraphDef(subgraph_def); 639 // Assuming input type is float if no data provided. 640 input_types->push_back(DT_FLOAT); 641 } 642 } 643 644 for (const string& output : outputs) { 645 DataType dt; 646 TensorShape shape; 647 const bool has_shapetype = 648 GetOutputTensorShapeType(subgraph_def, output, &dt, &shape); 649 650 execute_info->add_graph_output_node_name(output); 651 if (has_shapetype) { 652 RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& 653 tensor_shape_type_proto = 654 *execute_info->add_default_graph_output_tensor_shape(); 655 tensor_shape_type_proto.set_dtype(dt); 656 TensorShapeProto& tensor_shape_proto = 657 *tensor_shape_type_proto.mutable_shape(); 658 for (const int64 dim : shape.dim_sizes()) { 659 tensor_shape_proto.add_dim()->set_size(dim); 660 } 661 output_types->push_back(dt); 662 } else { 663 CHECK(!require_shape_type) 664 << "No shape type found for " << output << DumpGraphDef(subgraph_def); 665 // Assuming output type is float if no data provided. 666 output_types->push_back(DT_FLOAT); 667 } 668 } 669 670 return Status::OK(); 671 } 672 673 /* static */ Status 674 RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( 675 const string& node_name, const string& executor_name, 676 const GraphDef& subgraph_def, const std::vector<string>& inputs, 677 const std::vector<string>& outputs, const bool require_shape_type, 678 Graph* graph, Node** created_node) { 679 CHECK_NOTNULL(graph); 680 CHECK_NOTNULL(created_node); 681 682 RemoteFusedGraphExecuteInfo execute_info; 683 DataTypeVector input_types; 684 DataTypeVector output_types; 685 686 TF_CHECK_OK(RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteInfo( 687 executor_name, subgraph_def, inputs, outputs, require_shape_type, 688 &execute_info, &input_types, &output_types)); 689 690 std::vector<NodeBuilder::NodeOut> node_out_list; 691 for (const string& input : inputs) { 692 const TensorId tid = ParseTensorName(input); 693 Node* node = FindMutableNodeByName(tid.first.ToString(), graph); 694 CHECK_NOTNULL(node); 695 node_out_list.emplace_back(node, tid.second); 696 } 697 698 const string execute_info_str = execute_info.SerializeAsString(); 699 700 auto builder = 701 NodeBuilder(node_name, "RemoteFusedGraphExecute") 702 .Input(node_out_list) 703 .Attr("Tinputs", input_types) 704 .Attr("Toutputs", output_types) 705 .Attr("serialized_remote_fused_graph_execute_info", execute_info_str); 706 707 TF_RETURN_IF_ERROR(builder.Finalize(graph, created_node)); 708 return Status::OK(); 709 } 710 711 /* static */ Status RemoteFusedGraphExecuteUtils::BuildIdentityOpNode( 712 const string& node_name, const string& input_node_name, 713 const int input_node_port, const DataType dt, Graph* graph, 714 Node** created_node) { 715 Node* node = FindMutableNodeByName(input_node_name, graph); 716 CHECK_NOTNULL(node); 717 NodeBuilder::NodeOut node_out(node, input_node_port); 718 719 auto builder = 720 NodeBuilder(node_name, "Identity").Input(node_out).Attr("T", dt); 721 722 TF_RETURN_IF_ERROR(builder.Finalize(graph, created_node)); 723 return Status::OK(); 724 } 725 726 /* static */ Status RemoteFusedGraphExecuteUtils::ClusterizeNodes( 727 const std::unordered_set<string>& node_names, const GraphDef& graph_def, 728 std::vector<ClusterInfo>* cluster_infos) { 729 Graph graph(OpRegistry::Global()); 730 ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); 731 TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner)); 732 std::unordered_set<string> remaining_nodes = node_names; 733 734 while (!remaining_nodes.empty()) { 735 ClusterInfo ci; 736 737 // Determine one cluster nodes 738 std::unordered_set<const Node*> visited; 739 std::deque<const Node*> queue; 740 queue.emplace_back(FindNodeByName(*remaining_nodes.begin(), graph)); 741 while (!queue.empty()) { 742 const Node* node = queue.front(); 743 CHECK_NOTNULL(node); 744 queue.pop_front(); 745 const string& node_name = node->name(); 746 if (node_names.count(node_name) > 0) { 747 std::get<0>(ci).emplace(node_name); 748 remaining_nodes.erase(node_name); 749 } else { 750 // Edge of subgraph. Do nothing. 751 continue; 752 } 753 for (const Node* in : node->in_nodes()) { 754 if (visited.insert(in).second) { 755 queue.push_back(in); 756 } 757 } 758 for (const Node* out : node->out_nodes()) { 759 if (visited.insert(out).second) { 760 queue.push_back(out); 761 } 762 } 763 } 764 765 // Determine one cluster border 766 std::vector<string>& border_inputs = std::get<1>(ci); 767 std::vector<string>& border_outputs = std::get<2>(ci); 768 for (const string& node_name : node_names) { 769 Node* node = FindMutableNodeByName(node_name, &graph); 770 CHECK_NOTNULL(node); 771 int input_count = 0; 772 for (const Edge* in_edge : node->in_edges()) { 773 const Node* src_node = in_edge->src(); 774 const bool src_is_outside = 775 node_names.count(src_node->name()) <= 0 && !src_node->IsSource(); 776 if (src_is_outside) { 777 const string src_name = 778 strings::StrCat(src_node->name(), ":", in_edge->src_output()); 779 CHECK_EQ(1, src_node->num_outputs()) 780 << "output count of input border node must be one." 781 << src_node->name(); 782 if (std::find(border_inputs.begin(), border_inputs.end(), src_name) == 783 border_inputs.end()) { 784 border_inputs.emplace_back(src_name); 785 } 786 } else { 787 ++input_count; 788 } 789 } 790 CHECK(input_count == 0 || input_count == node->in_edges().size()) 791 << "Invalid input_count(" << input_count << ", " 792 << node->in_edges().size() << ") " << node_name; 793 794 for (const Edge* out_edge : node->out_edges()) { 795 const Node* dst_node = out_edge->dst(); 796 CHECK_NOTNULL(dst_node); 797 const bool dst_is_outside = node_names.count(dst_node->name()) <= 0; 798 const string dst_name = 799 strings::StrCat(node->name(), ":", out_edge->src_output()); 800 if (dst_is_outside) { 801 if (dst_node->IsSink()) { 802 CHECK_EQ(1, node->num_outputs()) 803 << "If you want to specify output node as subgraph output node " 804 << "the output count of the node must be 1 " 805 << "because that node is replaced by identity node."; 806 const string identity_dst_name = 807 strings::StrCat(node->name(), ":", 0); 808 if (std::find(border_outputs.begin(), border_outputs.end(), 809 identity_dst_name) == border_outputs.end()) { 810 border_outputs.emplace_back(identity_dst_name); 811 } 812 } else { 813 if (std::find(border_outputs.begin(), border_outputs.end(), 814 dst_name) == border_outputs.end()) { 815 border_outputs.emplace_back(dst_name); 816 } 817 } 818 } 819 } 820 } 821 cluster_infos->emplace_back(ci); 822 VLOG(1) << DumpCluster(ci); 823 } 824 return Status::OK(); 825 } 826 827 /* static */ Status RemoteFusedGraphExecuteUtils::BuildClusterSubgraphDef( 828 const ClusterInfo& cluster, const GraphDef& graph_def, 829 GraphDef* subgraph_def) { 830 const std::unordered_set<string>& node_names = std::get<0>(cluster); 831 const std::unordered_set<string>& border_input_names = 832 BuildNodeSetFromNodeNamesAndPorts(std::get<1>(cluster)); 833 834 Graph graph(OpRegistry::Global()); 835 ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); 836 TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner)); 837 838 for (Node* node : graph.nodes()) { 839 if (node != nullptr && node_names.count(node->name()) <= 0 && 840 border_input_names.count(node->name()) <= 0 && !node->IsSource() && 841 !node->IsSink()) { 842 graph.RemoveNode(node); 843 } 844 } 845 graph.ToGraphDef(subgraph_def); 846 847 for (const string& subgraph_input : std::get<1>(cluster)) { 848 const TensorId tid = ParseTensorName(subgraph_input); 849 const string subgraph_input_name = tid.first.ToString(); 850 const int subgraph_input_port = tid.second; 851 const NodeDef* node_def = FindNodeDefByName(subgraph_input_name, graph_def); 852 CHECK_NOTNULL(node_def); 853 std::vector<DataType> dt_vec; 854 std::vector<TensorShape> shape_vec; 855 GetOutputTensorShapeType(*node_def, &dt_vec, &shape_vec).IgnoreError(); 856 const DataType& dt = 857 dt_vec.empty() ? DT_FLOAT : dt_vec.at(subgraph_input_port); 858 const TensorShape& shape = 859 shape_vec.empty() ? TensorShape({}) : shape_vec.at(subgraph_input_port); 860 861 TF_RETURN_IF_ERROR(ReplaceInputNodeByPlaceHolder(subgraph_input_name, dt, 862 shape, subgraph_def)); 863 } 864 865 // sort subgraph_def to align order in graph_def 866 std::unordered_map<string, int> name_to_id_map; 867 for (int i = 0; i < graph_def.node_size(); ++i) { 868 name_to_id_map.emplace(graph_def.node(i).name(), i); 869 } 870 std::sort(subgraph_def->mutable_node()->begin(), 871 subgraph_def->mutable_node()->end(), 872 [&name_to_id_map](const NodeDef& node0, const NodeDef& node1) { 873 CHECK(name_to_id_map.count(node0.name()) > 0); 874 CHECK(name_to_id_map.count(node1.name()) > 0); 875 const int id0 = name_to_id_map.at(node0.name()); 876 const int id1 = name_to_id_map.at(node1.name()); 877 return id0 < id1; 878 }); 879 880 VLOG(1) << DumpGraphDef(*subgraph_def); 881 return Status::OK(); 882 } 883 884 /* static */ Status RemoteFusedGraphExecuteUtils::BuildClusterByBorder( 885 const std::vector<string>& border_inputs, 886 const std::vector<string>& border_outputs, const GraphDef& graph_def, 887 ClusterInfo* cluster) { 888 Graph graph(OpRegistry::Global()); 889 ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); 890 TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner)); 891 892 std::unordered_set<const Node*> visited; 893 std::deque<const Node*> queue; 894 for (const string& output : border_outputs) { 895 const TensorId tid = ParseTensorName(output); 896 const string& output_node_name = tid.first.ToString(); 897 for (const Node* node : graph.nodes()) { 898 if (output_node_name == node->name()) { 899 queue.push_back(node); 900 visited.insert(node); 901 } 902 } 903 } 904 905 std::unordered_set<const Node*> border_input_nodes; 906 // propagate visit to parent nodes until input nodes 907 while (!queue.empty()) { 908 const Node* node = queue.front(); 909 queue.pop_front(); 910 for (const Edge* edge : node->in_edges()) { 911 const Node* src_node = edge->src(); 912 CHECK_NOTNULL(src_node); 913 const int src_port = edge->src_output(); 914 bool input_found = false; 915 for (const string& input : border_inputs) { 916 const TensorId tid = ParseTensorName(input); 917 if (tid.first.ToString() == src_node->name() && 918 tid.second == src_port) { 919 input_found = true; 920 border_input_nodes.insert(src_node); 921 } 922 } 923 if (visited.insert(src_node).second) { 924 if (!input_found) { 925 queue.push_back(src_node); 926 } 927 } 928 } 929 } 930 931 for (const Node* node : visited) { 932 if (node != nullptr && !node->IsSource() && !node->IsSink() && 933 border_input_nodes.count(node) <= 0) { 934 std::get<0>(*cluster).insert(node->name()); 935 } 936 } 937 std::get<1>(*cluster) = border_inputs; 938 std::get<2>(*cluster) = border_outputs; 939 return Status::OK(); 940 } 941 942 /* static */ Status RemoteFusedGraphExecuteUtils::FuseCluster( 943 const GraphDef& input_graph_def, const std::vector<string>& inputs, 944 const std::vector<string>& outputs, 945 const string& remote_fused_graph_node_name, const ClusterInfo& cluster, 946 const string& remote_graph_executor_name, const bool require_shape_type, 947 GraphDef* output_graph_def) { 948 LOG(INFO) << "Transforming quantized stripped model to a remote fused " 949 "graph execute op by fusing a specified subgraph..."; 950 951 CHECK(!remote_graph_executor_name.empty()); 952 953 const std::vector<string>& border_inputs = std::get<1>(cluster); 954 const std::vector<string>& border_outputs = std::get<2>(cluster); 955 956 GraphDef subgraph_def; 957 TF_RETURN_IF_ERROR( 958 BuildClusterSubgraphDef(cluster, input_graph_def, &subgraph_def)); 959 960 Graph graph(OpRegistry::Global()); 961 ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); 962 TF_RETURN_IF_ERROR( 963 ImportGraphDef({}, input_graph_def, &graph, &shape_refiner)); 964 965 Node* fused_node; 966 TF_RETURN_IF_ERROR(BuildRemoteFusedGraphExecuteOpNode( 967 remote_fused_graph_node_name, remote_graph_executor_name, subgraph_def, 968 border_inputs, border_outputs, require_shape_type, &graph, &fused_node)); 969 970 for (const Node* node : graph.nodes()) { 971 for (int i = 0; i < node->num_inputs(); ++i) { 972 const Edge* edge = nullptr; 973 TF_RETURN_IF_ERROR(node->input_edge(i, &edge)); 974 for (int j = 0; j < border_outputs.size(); ++j) { 975 const string& output = border_outputs.at(j); 976 const TensorId tid = ParseTensorName(output); 977 const string output_name = tid.first.ToString(); 978 Node* src_node = edge->src(); 979 if (src_node != nullptr && src_node->name() == output_name && 980 edge->src_output() == tid.second) { 981 // Source node is replaced by new fused node. 982 Node* dst_node = edge->dst(); 983 const int dst_input = edge->dst_input(); 984 LOG(INFO) << "Removing existing edge to " << edge->dst()->name() 985 << " from " << edge->src()->name(); 986 graph.RemoveEdge(edge); 987 graph.AddEdge(fused_node, j, dst_node, dst_input); 988 } 989 } 990 } 991 } 992 993 // Replace output nodes by identity nodes which forward outputs from 994 // RemoteFusedGraphExecuteOpNode 995 for (const string& output : outputs) { 996 const TensorId output_tid = ParseTensorName(output); 997 const string output_name = output_tid.first.ToString(); 998 for (size_t i = 0; i < border_outputs.size(); ++i) { 999 const TensorId subgraph_output_tid = 1000 ParseTensorName(border_outputs.at(i)); 1001 const string& subgraph_output_name = subgraph_output_tid.first.ToString(); 1002 if (output_name == subgraph_output_name) { 1003 LOG(INFO) << "As graph output and subgraph output are same, " 1004 << "the graph output node is replaced by identity node"; 1005 Node* original_output_node = FindMutableNodeByName(output_name, &graph); 1006 CHECK_NOTNULL(original_output_node); 1007 CHECK_EQ(1, original_output_node->num_outputs()) 1008 << "Num outputs should be 1 for " << output << "."; 1009 graph.RemoveNode(original_output_node); 1010 Node* new_node; 1011 TF_RETURN_IF_ERROR(BuildIdentityOpNode(output_name, 1012 remote_fused_graph_node_name, i, 1013 DT_FLOAT, &graph, &new_node)); 1014 CHECK_NOTNULL(new_node); 1015 } 1016 } 1017 } 1018 1019 GraphDef result_graph_def; 1020 1021 graph.ToGraphDef(&result_graph_def); 1022 1023 ClusterInfo graph_cluster; 1024 TF_RETURN_IF_ERROR( 1025 BuildClusterByBorder(inputs, outputs, result_graph_def, &graph_cluster)); 1026 1027 // Remove unvisited nodes 1028 TF_RETURN_IF_ERROR(BuildClusterSubgraphDef(graph_cluster, result_graph_def, 1029 output_graph_def)); 1030 1031 return Status::OK(); 1032 } 1033 1034 /* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByNodeNames( 1035 const GraphDef& input_graph_def, const std::vector<string>& inputs, 1036 const std::vector<string>& outputs, 1037 const string& remote_fused_graph_node_name_prefix, 1038 const std::unordered_set<string>& subgraph_nodes, 1039 const string& remote_fused_graph_executor_name, 1040 const bool require_shape_type, GraphDef* output_graph_def) { 1041 std::vector<ClusterInfo> ci_vec; 1042 TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::ClusterizeNodes( 1043 subgraph_nodes, input_graph_def, &ci_vec)); 1044 1045 for (size_t i = 0; i < ci_vec.size(); ++i) { 1046 const string remote_fused_graph_node_name = 1047 strings::StrCat(remote_fused_graph_node_name_prefix, "/", i); 1048 TF_RETURN_IF_ERROR(FuseCluster(input_graph_def, inputs, outputs, 1049 remote_fused_graph_node_name, ci_vec.at(i), 1050 remote_fused_graph_executor_name, 1051 require_shape_type, output_graph_def)); 1052 } 1053 return Status::OK(); 1054 } 1055 1056 /* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByBorder( 1057 const GraphDef& input_graph_def, const std::vector<string>& inputs, 1058 const std::vector<string>& outputs, 1059 const string& remote_fused_graph_node_name, 1060 const std::vector<string>& border_inputs, 1061 const std::vector<string>& border_outputs, 1062 const string& remote_graph_executor_name, const bool require_shape_type, 1063 GraphDef* output_graph_def) { 1064 ClusterInfo cluster; 1065 TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::BuildClusterByBorder( 1066 border_inputs, border_outputs, input_graph_def, &cluster)); 1067 1068 return FuseCluster( 1069 input_graph_def, inputs, outputs, remote_fused_graph_node_name, cluster, 1070 remote_graph_executor_name, require_shape_type, output_graph_def); 1071 } 1072 1073 /* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByOpTypes( 1074 const GraphDef& input_graph_def, const std::vector<string>& inputs, 1075 const std::vector<string>& outputs, 1076 const string& remote_fused_graph_node_name_prefix, 1077 const std::unordered_set<string>& fused_op_types, 1078 const string& remote_fused_graph_executor_name, 1079 const bool require_shape_type, GraphDef* output_graph_def) { 1080 const std::unordered_set<string> fused_nodes_filtered_by_op_types = 1081 BuildNodeMapFromOpTypes(input_graph_def, fused_op_types); 1082 1083 return FuseRemoteGraphByNodeNames( 1084 input_graph_def, inputs, outputs, remote_fused_graph_node_name_prefix, 1085 fused_nodes_filtered_by_op_types, remote_fused_graph_executor_name, 1086 require_shape_type, output_graph_def); 1087 } 1088 1089 /* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor( 1090 const GraphDef& input_graph_def, const std::vector<string>& inputs, 1091 const std::vector<string>& outputs, const string& executor_name, 1092 GraphDef* output_graph_def) { 1093 const ExecutorBuildFunc* build_func = GetExecutorBuildFunc(executor_name); 1094 if (build_func == nullptr) { 1095 return errors::InvalidArgument("Unknown executor name: " + executor_name); 1096 } 1097 std::unique_ptr<IRemoteFusedGraphExecutor> executor; 1098 TF_RETURN_IF_ERROR((*build_func)(&executor)); 1099 CHECK_NOTNULL(executor.get()); 1100 if (!executor->IsEnabled()) { 1101 // As this executor is not enabled, just return original graph as is. 1102 *output_graph_def = input_graph_def; 1103 return Status::OK(); 1104 } 1105 return executor->FuseRemoteGraph(input_graph_def, inputs, outputs, 1106 output_graph_def); 1107 } 1108 1109 /* static */ Status RemoteFusedGraphExecuteUtils::PlaceRemoteGraphArguments( 1110 const std::vector<string>& inputs, const std::vector<string>& outputs, 1111 const std::unordered_set<string>& fused_node_names, 1112 const std::vector<string>& border_inputs, 1113 const std::vector<string>& border_outputs, 1114 const std::unordered_set<string>& fused_op_types, 1115 const string& remote_fused_graph_node_name, 1116 const string& remote_graph_executor_name, GraphDef* graph_def) { 1117 CHECK_NOTNULL(graph_def); 1118 1119 const std::unordered_set<string> fused_nodes_filtered_by_op_types = 1120 BuildNodeMapFromOpTypes(*graph_def, fused_op_types); 1121 1122 for (NodeDef& node_def : *graph_def->mutable_node()) { 1123 string attr_str; 1124 TensorId tid; 1125 for (size_t i = 0; i < inputs.size(); ++i) { 1126 if (IsSameNodeName(node_def, inputs.at(i), &tid)) { 1127 AppendDeliminator(&attr_str); 1128 attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::GRAPH_INPUT, 1129 tid.second, i, remote_graph_executor_name, 1130 remote_fused_graph_node_name); 1131 } 1132 } 1133 for (size_t i = 0; i < outputs.size(); ++i) { 1134 if (IsSameNodeName(node_def, outputs.at(i), &tid)) { 1135 AppendDeliminator(&attr_str); 1136 attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::GRAPH_OUTPUT, 1137 tid.second, i); 1138 } 1139 } 1140 for (const string& fused_node_name : fused_node_names) { 1141 if (fused_node_name == node_def.name()) { 1142 AppendDeliminator(&attr_str); 1143 attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::FUSED_NODE); 1144 } 1145 } 1146 for (const string& fused_node_name : fused_nodes_filtered_by_op_types) { 1147 if (fused_node_name == node_def.name()) { 1148 AppendDeliminator(&attr_str); 1149 attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::FUSED_NODE); 1150 } 1151 } 1152 for (size_t i = 0; i < border_inputs.size(); ++i) { 1153 if (IsSameNodeName(node_def, border_inputs.at(i), &tid)) { 1154 AppendDeliminator(&attr_str); 1155 attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::BORDER_INPUT, 1156 tid.second, i); 1157 } 1158 } 1159 for (size_t i = 0; i < border_outputs.size(); ++i) { 1160 if (IsSameNodeName(node_def, border_outputs.at(i), &tid)) { 1161 AppendDeliminator(&attr_str); 1162 attr_str += BuildNodeTypeAttr( 1163 RemoteFusedGraphExecuteInfo::BORDER_OUTPUT, tid.second, i); 1164 } 1165 } 1166 if (attr_str.empty()) { 1167 attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::UNUSED); 1168 } 1169 AddNodeAttr(ATTR_NODE_TYPE, attr_str, &node_def); 1170 } 1171 return Status::OK(); 1172 } 1173 1174 /* static */ Status 1175 RemoteFusedGraphExecuteUtils::FuseRemoteGraphByPlacedArguments( 1176 const GraphDef& input_graph_def, 1177 const std::vector<std::pair<string, Tensor>>& input_tensors, 1178 GraphDef* output_graph_def) { 1179 std::unordered_map<int, string> input_map; 1180 std::unordered_map<int, string> output_map; 1181 std::unordered_set<string> fused_node_names; 1182 std::unordered_map<int, string> border_input_map; 1183 std::unordered_map<int, string> border_output_map; 1184 string remote_graph_executor_name; 1185 string remote_fused_graph_node_name; 1186 1187 for (const NodeDef& node_def : input_graph_def.node()) { 1188 string attr_str; 1189 TF_RETURN_IF_ERROR(GetNodeAttr(node_def, ATTR_NODE_TYPE, &attr_str)); 1190 std::vector<std::vector<string>> attr_strs; 1191 for (const string& str : str_util::Split(attr_str, ":")) { 1192 attr_strs.emplace_back(str_util::Split(str, ",")); 1193 } 1194 if (attr_strs.empty()) { 1195 return errors::InvalidArgument("Remote graph node type not found."); 1196 } 1197 for (const std::vector<string>& attr : attr_strs) { 1198 if (attr.empty()) { 1199 return errors::InvalidArgument("Empty remote graph node type attr."); 1200 } 1201 int node_type_int; 1202 CHECK(strings::safe_strto32(attr.at(0), &node_type_int)) << attr.at(0); 1203 const RemoteFusedGraphExecuteInfo::NodeType node_type = 1204 static_cast<RemoteFusedGraphExecuteInfo::NodeType>(node_type_int); 1205 const string& name = node_def.name(); 1206 int port; 1207 int index; 1208 1209 switch (node_type) { 1210 case RemoteFusedGraphExecuteInfo::GRAPH_INPUT: 1211 VLOG(2) << "Graph input: " << name; 1212 CHECK_EQ(5, attr.size()); 1213 CHECK(strings::safe_strto32(attr.at(1), &port)); 1214 CHECK(strings::safe_strto32(attr.at(2), &index)); 1215 CHECK(!attr.at(3).empty()); 1216 remote_graph_executor_name = attr.at(3); 1217 CHECK(!attr.at(4).empty()); 1218 remote_fused_graph_node_name = attr.at(4); 1219 input_map.emplace(index, strings::StrCat(name, ":", port)); 1220 if (GetExecutorBuildFunc(remote_graph_executor_name) == nullptr) { 1221 LOG(INFO) << "Executor for " << remote_graph_executor_name 1222 << " not registered. Do not fuse."; 1223 *output_graph_def = input_graph_def; 1224 return Status::OK(); 1225 } 1226 break; 1227 case RemoteFusedGraphExecuteInfo::GRAPH_OUTPUT: 1228 VLOG(2) << "Graph output: " << name; 1229 CHECK_EQ(3, attr.size()); 1230 CHECK(strings::safe_strto32(attr.at(1), &port)); 1231 CHECK(strings::safe_strto32(attr.at(2), &index)); 1232 output_map.emplace(index, strings::StrCat(name, ":", port)); 1233 break; 1234 case RemoteFusedGraphExecuteInfo::FUSED_NODE: 1235 VLOG(2) << "Fused node: " << name; 1236 CHECK_EQ(1, attr.size()); 1237 fused_node_names.emplace(name); 1238 break; 1239 case RemoteFusedGraphExecuteInfo::BORDER_INPUT: 1240 VLOG(2) << "Border input: " << name; 1241 CHECK_EQ(3, attr.size()); 1242 CHECK(strings::safe_strto32(attr.at(1), &port)); 1243 CHECK(strings::safe_strto32(attr.at(2), &index)); 1244 border_input_map.emplace(index, strings::StrCat(name, ":", port)); 1245 break; 1246 case RemoteFusedGraphExecuteInfo::BORDER_OUTPUT: 1247 VLOG(2) << "Border output: " << name; 1248 CHECK_EQ(3, attr.size()); 1249 CHECK(strings::safe_strto32(attr.at(1), &port)); 1250 CHECK(strings::safe_strto32(attr.at(2), &index)); 1251 border_output_map.emplace(index, strings::StrCat(name, ":", port)); 1252 break; 1253 case RemoteFusedGraphExecuteInfo::UNUSED: 1254 // do nothing 1255 break; 1256 default: 1257 // unsupported value 1258 LOG(FATAL); 1259 } 1260 } 1261 } 1262 bool require_shape_type = false; 1263 std::vector<string> inputs; 1264 std::vector<string> outputs; 1265 std::vector<string> border_inputs; 1266 std::vector<string> border_outputs; 1267 ConvertMapToVector(input_map, &inputs); 1268 ConvertMapToVector(output_map, &outputs); 1269 ConvertMapToVector(border_input_map, &border_inputs); 1270 ConvertMapToVector(border_output_map, &border_outputs); 1271 1272 if (!input_tensors.empty()) { 1273 bool input_match = false; 1274 if (inputs.size() == input_tensors.size()) { 1275 for (const std::pair<string, Tensor>& input_tensor : input_tensors) { 1276 if (!ContainsSameTensorId(input_tensor.first, inputs)) { 1277 break; 1278 } 1279 DataType data_type; 1280 TensorShape shape; 1281 if (GetOutputTensorShapeType(input_graph_def, input_tensor.first, 1282 &data_type, &shape)) { 1283 if (data_type == input_tensor.second.dtype() && 1284 shape == input_tensor.second.shape()) { 1285 VLOG(2) << "Input matched!"; 1286 // Shape type matched. 1287 input_match = true; 1288 require_shape_type = true; 1289 } 1290 } else { 1291 // Shape type not required. 1292 input_match = true; 1293 } 1294 } 1295 } 1296 if (!input_match) { 1297 // Input mismatch. Just copy original graph 1298 *output_graph_def = input_graph_def; 1299 return Status::OK(); 1300 } 1301 } 1302 1303 if (!fused_node_names.empty()) { 1304 TF_RETURN_IF_ERROR(FuseRemoteGraphByNodeNames( 1305 input_graph_def, inputs, outputs, remote_fused_graph_node_name, 1306 fused_node_names, remote_graph_executor_name, require_shape_type, 1307 output_graph_def)); 1308 } else if (!border_inputs.empty() || !border_outputs.empty()) { 1309 TF_RETURN_IF_ERROR(FuseRemoteGraphByBorder( 1310 input_graph_def, inputs, outputs, remote_fused_graph_node_name, 1311 border_inputs, border_outputs, remote_graph_executor_name, 1312 require_shape_type, output_graph_def)); 1313 } else { 1314 *output_graph_def = input_graph_def; 1315 } 1316 1317 return Status::OK(); 1318 } 1319 1320 /* static */ bool RemoteFusedGraphExecuteUtils::IsFuseReady( 1321 const GraphDef& graph_def, 1322 const std::vector<std::pair<string, Tensor>>& input_tensors) { 1323 for (const std::pair<string, Tensor>& input_tensor : input_tensors) { 1324 const NodeDef* node_def = FindNodeDefByName(input_tensor.first, graph_def); 1325 if (node_def == nullptr) { 1326 return false; 1327 } 1328 string attr; 1329 const Status status = GetNodeAttr(*node_def, ATTR_NODE_TYPE, &attr); 1330 if (!status.ok() || attr.empty()) { 1331 return false; 1332 } 1333 } 1334 return true; 1335 } 1336 1337 /* static */ Status RemoteFusedGraphExecuteUtils::CopyByteArrayToTensor( 1338 const void* src_ptr, const int src_size, Tensor* tensor) { 1339 CHECK(tensor->TotalBytes() >= src_size) 1340 << tensor->TotalBytes() << ", " << src_size; 1341 void* dst_ptr; 1342 switch (tensor->dtype()) { 1343 case DT_FLOAT: 1344 dst_ptr = tensor->flat<float>().data(); 1345 break; 1346 case DT_DOUBLE: 1347 dst_ptr = tensor->flat<double>().data(); 1348 break; 1349 case DT_INT32: 1350 dst_ptr = tensor->flat<int32>().data(); 1351 break; 1352 case DT_UINT8: 1353 dst_ptr = tensor->flat<uint8>().data(); 1354 break; 1355 case DT_INT16: 1356 dst_ptr = tensor->flat<int16>().data(); 1357 break; 1358 case DT_INT8: 1359 dst_ptr = tensor->flat<int8>().data(); 1360 break; 1361 case DT_STRING: 1362 dst_ptr = tensor->flat<string>().data(); 1363 break; 1364 case DT_INT64: 1365 dst_ptr = tensor->flat<int64>().data(); 1366 break; 1367 case DT_BOOL: 1368 dst_ptr = tensor->flat<bool>().data(); 1369 break; 1370 case DT_QINT8: 1371 dst_ptr = tensor->flat<qint8>().data(); 1372 break; 1373 case DT_QUINT8: 1374 dst_ptr = tensor->flat<quint8>().data(); 1375 break; 1376 case DT_QINT32: 1377 dst_ptr = tensor->flat<qint32>().data(); 1378 break; 1379 case DT_BFLOAT16: 1380 dst_ptr = tensor->flat<bfloat16>().data(); 1381 break; 1382 case DT_QINT16: 1383 dst_ptr = tensor->flat<qint16>().data(); 1384 break; 1385 case DT_QUINT16: 1386 dst_ptr = tensor->flat<quint16>().data(); 1387 break; 1388 case DT_UINT16: 1389 dst_ptr = tensor->flat<uint16>().data(); 1390 break; 1391 default: 1392 LOG(FATAL) << "type " << tensor->dtype() << " is not supported."; 1393 break; 1394 } 1395 CHECK_NOTNULL(dst_ptr); 1396 std::memcpy(dst_ptr, src_ptr, src_size); 1397 return Status::OK(); 1398 } 1399 1400 /* static */ std::unordered_set<string> 1401 RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpTypes( 1402 const GraphDef& graph_def, const std::unordered_set<string>& op_types) { 1403 std::unordered_set<string> retval; 1404 for (const NodeDef& node_def : graph_def.node()) { 1405 if (op_types.count(node_def.op()) > 0) { 1406 retval.emplace(node_def.name()); 1407 } 1408 } 1409 return retval; 1410 } 1411 1412 /* static */ std::unordered_set<string> 1413 RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions( 1414 const GraphDef& graph_def, 1415 const IRemoteFusedGraphOpsDefinitions& ops_definitions) { 1416 std::unordered_set<string> retval; 1417 for (const NodeDef& node_def : graph_def.node()) { 1418 std::vector<DataType> dt_vec; 1419 std::vector<TensorShape> shape_vec; 1420 const Status status = 1421 GetOutputTensorShapeType(node_def, &dt_vec, &shape_vec); 1422 if (!status.ok()) { 1423 shape_vec.clear(); 1424 } 1425 if (ops_definitions.GetOpIdFor( 1426 node_def.op(), DataTypeVector(dt_vec.begin(), dt_vec.end())) != 1427 IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID) { 1428 retval.emplace(node_def.name()); 1429 } 1430 } 1431 return retval; 1432 } 1433 1434 /* static */ Status RemoteFusedGraphExecuteUtils::ReplaceInputNodeByPlaceHolder( 1435 const string& input, const DataType type, const TensorShape& shape, 1436 GraphDef* graph_def) { 1437 const TensorId tid = ParseTensorName(input); 1438 CHECK_EQ(0, tid.second); 1439 const string node_name = tid.first.ToString(); 1440 for (NodeDef& node : *graph_def->mutable_node()) { 1441 if (node.name() != node_name) { 1442 continue; 1443 } 1444 if (node.op() == "Placeholder") { 1445 return Status::OK(); 1446 } else { 1447 NodeDef placeholder_node; 1448 placeholder_node.set_op("Placeholder"); 1449 placeholder_node.set_name(node_name); 1450 AddNodeAttr("dtype", type, &placeholder_node); 1451 AddNodeAttr("shape", shape, &placeholder_node); 1452 // TODO(satok): Remove once we merge attributes 1453 AddOutputTensorShapeType({type}, {shape}, &placeholder_node); 1454 node.Clear(); 1455 node = placeholder_node; 1456 return Status::OK(); 1457 } 1458 } 1459 return errors::InvalidArgument( 1460 strings::StrCat(node_name, " not found for replacement.")); 1461 } 1462 1463 /* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr( 1464 const RemoteFusedGraphExecuteInfo::NodeType node_type, const int port, 1465 const int index, const string& executor_name, const string& node_name) { 1466 return strings::StrCat(static_cast<int>(node_type), ",", port, ",", index, 1467 ",", executor_name, ",", node_name); 1468 } 1469 1470 /* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr( 1471 const RemoteFusedGraphExecuteInfo::NodeType node_type, const int port, 1472 const int index) { 1473 return strings::StrCat(static_cast<int>(node_type), ",", port, ",", index); 1474 } 1475 1476 /* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr( 1477 const RemoteFusedGraphExecuteInfo::NodeType node_type) { 1478 return strings::StrCat(static_cast<int>(node_type)); 1479 } 1480 1481 } // namespace tensorflow 1482