/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/hexagon/graph_transferer.h"

#include <algorithm>
#include <cinttypes>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/tensor_slice_writer.h"

namespace tensorflow {

// Function alias: shortens calls to the shape/type propagation helper used in
// LoadGraphFromProtoFile below.
constexpr auto AddOutputTensorShapeTypeByTensorShapeMap =
    &RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap;

// Debug switches: when flipped to true, LoadGraphFromProto dumps the collected
// transfer parameters (DBG_DUMP_PARAMS) or a verification string
// (DBG_DUMP_VERIFICATION_STRING) after a successful load.
constexpr bool DBG_DUMP_VERIFICATION_STRING = false;
constexpr bool DBG_DUMP_PARAMS = false;

// Well-known op/node-name strings and prefixes used when registering nodes
// and synthesizing const nodes in the transferred graph.
const char RESHAPE_NODE_TYPE_STRING[] = "Reshape";
const char SOURCE_NODE_NAME[] = "_SOURCE";
const char SINK_NODE_NAME[] = "_SINK";
const char INPUTS_NODE_PREFIX[] = "inputs_for_";
const char OUTPUTS_NODE_PREFIX[] = "outputs_for_";
const char DATA_NODE_PREFIX[] = "data_for_op_";
const char CONST_SHAPE_PREFIX[] = "const_shape_";
const char CONST_VAL_PREFIX[] = "const_val_";
const char CONST_TENSOR_PREFIX[] = "const_tensor_";
// Attribute names read off nodes during registration.
const char PADDING_ATTR_NAME[] = "padding";
const char STRIDES_ATTR_NAME[] = "strides";
const char KEEP_DIMS_ATTR_NAME[] = "keep_dims";
const char KSIZE_ATTR_NAME[] = "ksize";
const char NULL_OUTPUT_NAME[] = "NULL";
// Name of the synthetic node that aggregates all graph inputs; see
// TransformGraphToAddAggregatedInputNode.
const char AGGREGATED_INPUT_NODE_NAME[] = "graph_transfer_aggregated_input";
// Padding id meaning "not applicable".  VALID = 1, SAME = 2 (Padding enum).
const int PADDING_NA_ID = 0;  // VALID = 1, SAME = 2

// Stringify any streamable value.
// This is a temporary workaround to support android build
// where std::string is not supported even with c++11 option.
template <typename T>
static string ToString(T val) {
  std::stringstream stream;
  stream << val;
  return stream.str();
}

// Linear scan over all graph nodes for the node whose name matches the node
// part of "name" (a possibly port-qualified tensor name such as "foo:0").
// Returns nullptr when no node matches.
static Node* FindMutableNodeByName(const string& name, Graph* graph) {
  const TensorId tid = ParseTensorName(name);
  for (Node* node : graph->nodes()) {
    if (node != nullptr && node->name() == tid.first) {
      return node;
    }
  }
  return nullptr;
}

/**
 * graph loading functions
 * - LoadGraphFromProto
 * - LoadGraphFromProtoFile
 * These functions read a graph definition and store parameters
 * of node to transfer the graph to SOC.
 */
// Builds graph_transfer_info_ from an in-memory GraphDef:
// imports it into a Graph, optionally runs shape inference for unknown
// shapes, rewrites inputs into one aggregated input node, registers every
// node, sorts the parameters by dependency order, and records the graph's
// input/output node metadata.  Returns the first non-OK status encountered.
Status GraphTransferer::LoadGraphFromProto(
    const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    const GraphDef& graph_def,
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    const std::vector<string>& output_node_names,
    const bool shape_inference_for_unknown_shape) {
  Graph graph(OpRegistry::Global());
  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
  Status status = ImportGraphDef({}, graph_def, &graph, &shape_refiner);
  if (!status.ok()) {
    return status;
  }

  if (shape_inference_for_unknown_shape) {
    status = RemoteFusedGraphExecuteUtils::PropagateShapeInference(
        graph_def, input_node_info_list, &graph, &shape_refiner);
    if (!status.ok()) {
      return status;
    }
  }

  // Replace the individual input nodes with a single aggregated input node
  // so the fused graph has one entry point.
  TF_RETURN_IF_ERROR(TransformGraphToAddAggregatedInputNode(
      input_node_info_list, &graph, &shape_refiner));

  // NOTE(review): op_name_to_node_multimap is populated below but never read
  // afterwards in this function — looks like dead state kept only for the
  // VLOG dependency dump; confirm before removing.
  std::unordered_multimap<string, const Node*> op_name_to_node_multimap(
      graph.num_nodes());
  // First pass: assign a stable id to every node (see CacheNode).
  for (const Node* const node : graph.nodes()) {
    if (node == nullptr) {
      continue;
    }
    CacheNode(*node);
  }

  // Second pass: log and record the input -> consumer dependency edges.
  for (const Node* const node : graph.nodes()) {
    if (node == nullptr) {
      continue;
    }
    VLOG(1) << "<Node> " << node->name();
    for (const Node* const input_node : node->in_nodes()) {
      const string& name = input_node->name();
      op_name_to_node_multimap.emplace(name, node);
      VLOG(1) << "Add dependency: " << name << " -> " << node->name();
    }
  }

  // Third pass: register each node's transfer parameters.  All inputs are
  // guaranteed cached by the first pass.
  for (const Node* const node : graph.nodes()) {
    if (node == nullptr) {
      continue;
    }
    status = RegisterNodeIfAllInputsAreCached(
        ops_definitions, shape_refiner, *node, false, input_node_info_list,
        output_node_names);
    if (!status.ok()) {
      LOG(ERROR) << "Failed to transfer graph " << status;
      return status;
    }
  }

  // Order node parameters so every node appears after its dependencies.
  SortParams(output_node_names);

  // Record name/dtype/shape for each graph input.
  for (const std::pair<string, Tensor>& input_node_info :
       input_node_info_list) {
    GraphTransferInfo::GraphInputNodeInfo& graph_input_node_info =
        *graph_transfer_info_.add_graph_input_node_info();
    graph_input_node_info.set_name(input_node_info.first);
    graph_input_node_info.set_dtype(input_node_info.second.dtype());
    for (const int64 dim : ToTensorShapeArray(input_node_info.second.shape())) {
      graph_input_node_info.add_shape(dim);
    }
  }

  // Record name/dtype/shape for each requested graph output (port-qualified).
  for (const string& output_node_name : output_node_names) {
    const TensorId tid = ParseTensorName(output_node_name);
    const string node_name = tid.first.ToString();
    const int port = tid.second;
    const int node_id = node_name_to_id_cache_map_.at(node_name);
    const Node* node = node_name_cache_list_.at(node_id);
    CHECK_NOTNULL(node);

    GraphTransferInfo::GraphOutputNodeInfo& graph_output_node_info =
        *graph_transfer_info_.add_graph_output_node_info();
    graph_output_node_info.set_name(strings::StrCat(node_name, ":", port));

    // Get output tensor shape type
    std::vector<DataType> data_types;
    std::vector<TensorShape> shapes;
    status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
        node->attrs(), &data_types, &shapes);
    if (status.ok()) {
      // NOTE(review): signed/unsigned comparison — data_types.size() is
      // size_t while port is int; benign for valid ports but worth a cast.
      CHECK(data_types.size() > port);
      graph_output_node_info.set_dtype(data_types.at(port));
      for (const int64 dim : ToTensorShapeArray(shapes.at(port))) {
        graph_output_node_info.add_shape(dim);
      }
    }
  }

  // The per-load node cache is no longer needed once parameters are built.
  ClearCache();
  if (DBG_DUMP_PARAMS) {
    DumpNodeTransferParams();
  }
  if (DBG_DUMP_VERIFICATION_STRING) {
    DumpVerificationStringOfNodeTransferParams();
  }
  // Default-constructed Status is OK (same as Status::OK()).
  return Status();
}

// Reads a GraphDef from a text or binary proto file, optionally dry-runs the
// graph to infer per-node output shapes, then delegates to
// LoadGraphFromProto.
Status GraphTransferer::LoadGraphFromProtoFile(
    const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    const string& graph_def_path,
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    const std::vector<string>& output_node_names, const bool is_text_proto,
    const bool shape_inference_for_unknown_shape,
    const bool
    dry_run_for_unknown_shape) {
  GraphDef graph_def;
  string output;
  Status status;
  VLOG(1) << "Parse file " << graph_def_path;
  if (is_text_proto) {
    status = ReadFileToString(Env::Default(), graph_def_path, &output);
    // NOTE(review): if ReadFileToString fails, `output` is empty and this
    // branch returns the parse error below instead of the (more informative)
    // read error — confirm whether that masking is intended.
    if (!protobuf::TextFormat::ParseFromString(output, &graph_def)) {
      return errors::InvalidArgument("Cannot parse proto string.");
    }
  } else {
    status = ReadBinaryProto(Env::Default(), graph_def_path, &graph_def);
  }
  if (!status.ok()) {
    VLOG(1) << "Failed to load graph " << status;
    return status;
  }
  if (dry_run_for_unknown_shape) {
    // Execute the graph once with the given inputs to learn every node's
    // output shape, then record those shapes as node attributes.
    VLOG(1) << "Dry run graph to obtain shape of nodes";
    RemoteFusedGraphExecuteUtils::TensorShapeMap tensor_shape_map;
    status = RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode(
        graph_def, input_node_info_list, true, &tensor_shape_map);
    if (!status.ok()) {
      return status;
    }
    for (NodeDef& node_def : *graph_def.mutable_node()) {
      TF_CHECK_OK(AddOutputTensorShapeTypeByTensorShapeMap(tensor_shape_map,
                                                           &node_def));
    }
  }
  VLOG(1) << "Load graph with output tensors";
  return LoadGraphFromProto(ops_definitions, graph_def, input_node_info_list,
                            output_node_names,
                            shape_inference_for_unknown_shape);
}

// Topologically orders graph_transfer_info_.node_info() so that every node
// appears after all nodes it depends on, using a dependency map seeded from
// the requested output nodes.
void GraphTransferer::SortParams(const std::vector<string>& output_node_names) {
  // TODO(satok): optimize complexity
  // Index input-info records by node id for O(1) lookup below.
  std::unordered_map<int, GraphTransferInfo::NodeInputInfo*> input_map;
  for (GraphTransferInfo::NodeInputInfo& input :
       *graph_transfer_info_.mutable_node_input_info()) {
    input_map.emplace(input.node_id(), &input);
  }

  // Setup dependency map placeholder
  std::vector<int> output_node_ids;
  std::unordered_map<int, std::unordered_set<int>> dependency_map;
  for (const GraphTransferInfo::NodeInfo& params :
       graph_transfer_info_.node_info()) {
    const int node_id = params.node_id();
    // Collect ids of the nodes the caller asked for as outputs.
    for (const string& output_node_name : output_node_names) {
      if (params.name() == output_node_name) {
        output_node_ids.emplace_back(node_id);
      }
    }

    // Ensure every node id has an (initially empty) dependency set.
    dependency_map.emplace(std::piecewise_construct, std::make_tuple(node_id),
                           std::make_tuple());
    if (params.input_count() == 0) {
      continue;
    }
    CHECK_EQ(input_map.count(node_id), 1);
    for (const GraphTransferInfo::NodeInput& node_input :
         input_map.at(node_id)->node_input()) {
      dependency_map.at(node_id).emplace(node_input.node_id());
    }
  }

  // Create dependency map traversed from output nodes
  std::unordered_set<int> completed;
  for (int output_node_id : output_node_ids) {
    FillDependencyRec(output_node_id, dependency_map, completed);
  }

  std::sort(graph_transfer_info_.mutable_node_info()->begin(),
            graph_transfer_info_.mutable_node_info()->end(),
            TransferParamsComparator(dependency_map));
}

// Toggles strict checking during transfer.
void GraphTransferer::EnableStrictCheckMode(const bool enable) {
  strict_check_mode_ = enable;
}

// Restores graph_transfer_info_ from a serialized GraphTransferInfo proto.
void GraphTransferer::SetSerializedGraphTransferInfo(
    const string& serialized_proto) {
  graph_transfer_info_.ParseFromString(serialized_proto);
}

// Read-only accessor for the accumulated transfer parameters.
const GraphTransferInfo& GraphTransferer::GetGraphTransferInfo() const {
  return graph_transfer_info_;
}

// Mutable accessor for the accumulated transfer parameters.
GraphTransferInfo& GraphTransferer::GetMutableGraphTransferInfo() {
  return graph_transfer_info_;
}

// Assigns the node the next sequential id and records the name -> id mapping.
// Idempotent: a node name that is already cached is left untouched.
void GraphTransferer::CacheNode(const Node& node) {
  if (node_name_to_id_cache_map_.count(node.name()) > 0) {
    return;
  }
  node_name_cache_list_.emplace_back(&node);
  const int node_id = node_name_cache_list_.size() - 1;
  bool emplace_succeeded = false;
  std::tie(std::ignore, emplace_succeeded) =
      node_name_to_id_cache_map_.emplace(node.name(), node_id);
  // The count() guard above makes a failed emplace a logic error.
  CHECK(emplace_succeeded);
}

// True iff every direct input of `node` has already been assigned an id via
// CacheNode.
bool GraphTransferer::AreAllInputsCached(const Node& node) const {
  for (const Node* const input_node : node.in_nodes()) {
    if
        (node_name_to_id_cache_map_.count(input_node->name()) <= 0) {
      VLOG(1) << "input_node " << input_node->name() << " of " << node.name()
              << " is not cached yet.";
      return false;
    }
  }
  return true;
}

// Rewrites `graph` so that all remote-graph inputs flow through a single
// synthetic "RemoteFusedGraphExecute" node (AGGREGATED_INPUT_NODE_NAME):
// each original input node is replaced by an Identity node fed from one
// output port of the aggregated node, edges are rewired, shapes are refined,
// and the originals are removed.  The aggregated node and its identities are
// cached first so they receive the lowest ids.
Status GraphTransferer::TransformGraphToAddAggregatedInputNode(
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    Graph* graph, ShapeRefiner* shape_refiner) {
  // Transform a remote fused graph to add an aggregated input node which takes
  // all inputs of the remote graph.
  DataTypeVector input_data_types;
  std::vector<DataType> data_types;
  std::vector<TensorShape> shapes;
  std::vector<string> input_nodes;
  // Gather name/dtype/shape of every declared input.
  for (int i = 0; i < input_node_info_list.size(); ++i) {
    Node* node = FindMutableNodeByName(input_node_info_list.at(i).first, graph);
    CHECK_NOTNULL(node);
    input_nodes.emplace_back(node->name());
    input_data_types.emplace_back(input_node_info_list.at(i).second.dtype());
    data_types.emplace_back(input_node_info_list.at(i).second.dtype());
    shapes.emplace_back(input_node_info_list.at(i).second.shape());
  }

  // Build the aggregated input node: no inputs, one output per original
  // graph input, carrying the collected dtypes/shapes as attributes.
  NodeDef input_node_def;
  auto builder =
      NodeBuilder(AGGREGATED_INPUT_NODE_NAME, "RemoteFusedGraphExecute")
          .Input(std::vector<NodeBuilder::NodeOut>{})
          .Attr("Tinputs", DataTypeVector{})
          .Attr("Toutputs", input_data_types)
          .Attr("serialized_remote_fused_graph_execute_info", "")
          .Attr(RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES,
                data_types)
          .Attr(RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES, shapes);

  Node* input_node;
  TF_RETURN_IF_ERROR(builder.Finalize(graph, &input_node));
  CHECK_NOTNULL(input_node);

  bool refined;
  TF_RETURN_IF_ERROR(
      shape_refiner->UpdateNode(input_node, false /* relax */, &refined));

  // Pin each output port's shape to the corresponding input tensor's shape.
  shape_inference::InferenceContext* context =
      shape_refiner->GetContext(input_node);
  for (int i = 0; i < input_node_info_list.size(); ++i) {
    shape_inference::ShapeHandle handle;
    TF_RETURN_IF_ERROR(context->MakeShapeFromTensorShape(
        input_node_info_list.at(i).second.shape(), &handle));
    TF_RETURN_IF_ERROR(shape_refiner->SetShape(input_node, i, handle));
  }

  // Cache the aggregate input node first as it's consumed first.
  CacheNode(*input_node);

  std::vector<Node*> original_input_nodes(input_nodes.size());

  for (int i = 0; i < input_nodes.size(); ++i) {
    const string& node_name = input_nodes.at(i);
    Node* original_input_node = FindMutableNodeByName(node_name, graph);
    CHECK_NOTNULL(original_input_node);
    CHECK_EQ(1, original_input_node->num_outputs());  // replaced by identity.
    // Replace the original input with an Identity fed from port i of the
    // aggregated input node, reusing the original node's name.
    Node* created_node;
    TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::BuildIdentityOpNode(
        node_name, AGGREGATED_INPUT_NODE_NAME, i, data_types.at(i), graph,
        &created_node));
    CHECK_NOTNULL(created_node);
    // NOTE(review): these locals shadow the outer data_types/shapes declared
    // at function scope — intentional here (per-node attrs), but confusing.
    std::vector<DataType> data_types;
    std::vector<TensorShape> shapes;
    Status status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
        original_input_node->attrs(), &data_types, &shapes);
    if (status.ok()) {
      // Carry the original node's recorded output shape/type onto the
      // replacement; best-effort (skipped when attrs are absent).
      created_node->AddAttr(
          RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES, data_types);
      created_node->AddAttr(RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES,
                            shapes);
    }
    // Re-point all consumers of the original input at the new identity.
    for (const Edge* out_edge : original_input_node->out_edges()) {
      Node* dst = out_edge->dst();
      int dst_port = out_edge->dst_input();
      // Unused edge will be removed when removing node.
      graph->AddEdge(created_node, 0, dst, dst_port);
    }
    original_input_nodes[i] = original_input_node;

    TF_RETURN_IF_ERROR(
        shape_refiner->UpdateNode(created_node, false /* relax */, &refined));

    shape_inference::InferenceContext* context =
        shape_refiner->GetContext(created_node);
    CHECK_NOTNULL(context);

    // Cache replaced input node next to the aggregated input node.
    CacheNode(*created_node);
  }

  // Remove original input nodes after adding new input nodes to avoid
  // reusing same pointer in Graph.
  for (Node* original_input_node : original_input_nodes) {
    graph->RemoveNode(original_input_node);
  }

  return Status::OK();
}

// Dispatches a node to the appropriate Register* routine based on its kind
// (source/sink are skipped; aggregated input, constant, pad,
// padding-and-strides, rank-needing, flatten-reshape, then generic ops).
// Returns InvalidArgument for op types with no SOC op id.
Status GraphTransferer::RegisterNode(
    const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    const ShapeRefiner& shape_refiner, const Node& node,
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    const std::vector<string>& output_node_names) {
  VLOG(1) << "Register node: " << node.name() << ", " << std::hex
          << node_name_to_id_cache_map_.at(node.name());
  if (node.name() == SOURCE_NODE_NAME || node.name() == SINK_NODE_NAME) {
    // Just ignore sink and source
    return Status::OK();
  } else if (node.name() == AGGREGATED_INPUT_NODE_NAME) {
    RegisterInputNode(ops_definitions, shape_refiner, node);
    return Status::OK();
  } else if (node.IsConstant()) {
    RegisterConstantNode(shape_refiner, node);
  } else if (IsPadNode(node)) {
    RegisterPadNode(ops_definitions, shape_refiner, node);
  } else if (HasPaddingAndStrides(node)) {
    RegisterNodeWithPaddingAndStrides(ops_definitions, shape_refiner, node);
  } else if (NeedsToAddRank(node)) {
    RegisterNodeWithRank(ops_definitions, shape_refiner, node);
  } else if (IsNodeFlattenReshape(node, shape_refiner)) {
    RegisterFlattenNode(ops_definitions, shape_refiner, node);
  } else if (ops_definitions.GetOpIdFor(node.type_string(), {}) !=
             IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID) {
    // TODO(satok): Set correct data type if it's given.
    RegisterGenericNode(ops_definitions, shape_refiner, node);
  } else {
    return errors::InvalidArgument(node.type_string() +
                                   " has not been implemented yet.");
  }

  return Status::OK();
}

// Registers a Const node: records its id, 4-d shape (via BuildShapeArray),
// dtype, and raw tensor bytes into graph_transfer_info_.  The node's shape
// must be fully known to the shape refiner.
void GraphTransferer::RegisterConstantNode(const ShapeRefiner& shape_refiner,
                                           const Node& node) {
  VLOG(1) << "Register constant node: " << node.name();
  CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
  const int id = node_name_to_id_cache_map_[node.name()];
  const int output_node_size = node.num_outputs();
  CHECK_EQ(output_node_size, 1);
  // TODO(satok): support multiple outputs?
  const int output_index = 0;
  const DataType dt = node.output_type(output_index);
  const size_t max_bytes_per_data = DataTypeSize(dt);
  // DataTypeSize returns 0 for variable-length types; those cannot be
  // transferred as raw bytes.
  CHECK_GT(max_bytes_per_data, 0)
      << "dt = " << dt << ", " + DataTypeString(dt) << ", "
      << max_bytes_per_data << ", " << static_cast<int>(DataTypeSize(dt))
      << ",,,,,,,";
  shape_inference::InferenceContext* context = shape_refiner.GetContext(&node);
  shape_inference::ShapeHandle shape_handle = context->output(output_index);
  const shape_inference::DimensionHandle num_elements_dim =
      context->NumElements(shape_handle);
  std::array<int64, SHAPE_ARRAY_SIZE> shape_array;
  int data_size;
  // Shape of constant node must be known
  CHECK(context->ValueKnown(num_elements_dim));
  const int64 num_output_elements = context->Value(num_elements_dim);
  data_size = max_bytes_per_data * num_output_elements;
  shape_array = BuildShapeArray(shape_handle, context);

  GraphTransferInfo::ConstNodeInfo& const_node_info =
      *graph_transfer_info_.add_const_node_info();
  const_node_info.set_name(node.name());
  const_node_info.set_node_id(id);
  // TODO(satok): Make this generic. Never assume rank is 4.
  CHECK_EQ(4, SHAPE_ARRAY_SIZE);
  const_node_info.add_shape(shape_array[0]);
  const_node_info.add_shape(shape_array[1]);
  const_node_info.add_shape(shape_array[2]);
  const_node_info.add_shape(shape_array[3]);
  const TensorProto* proto = nullptr;
  TF_CHECK_OK(GetNodeAttr(node.attrs(), "value", &proto));
  Tensor const_tensor;
  TF_CHECK_OK(MakeTensorFromProto(*proto, &const_tensor));

  const_node_info.set_dtype(const_tensor.dtype());
  if (data_size > 0) {
    const_node_info.set_data(const_tensor.tensor_data().data(), data_size);
  }
}

// Registers (or reuses) a synthetic const node holding a 4-element shape
// (e.g. strides or ksize).  The node is keyed by a name derived from the
// shape values, so identical shapes share one const node.  Returns its id.
int GraphTransferer::RegisterConstantShape(const std::vector<int>& shape) {
  VLOG(1) << "Cache constant shape.";
  // TODO(satok): Handle non-4dim strides
  CHECK_EQ(shape.size(), 4);
  const string shape_name = CONST_SHAPE_PREFIX + ToString(shape.at(0)) + 'x' +
                            ToString(shape.at(1)) + 'x' +
                            ToString(shape.at(2)) + 'x' + ToString(shape.at(3));
  if (node_name_to_id_cache_map_.count(shape_name) <= 0) {
    // Synthetic nodes have no Graph Node, so cache a nullptr placeholder.
    node_name_cache_list_.emplace_back(nullptr);
    const int id = node_name_cache_list_.size() - 1;
    node_name_to_id_cache_map_.emplace(shape_name, id);
    GraphTransferInfo::ConstNodeInfo& const_node_info =
        *graph_transfer_info_.add_const_node_info();
    const_node_info.set_name(shape_name);
    const_node_info.set_node_id(id);
    // TODO(satok): Make this generic. Never assume rank is 4.
    const_node_info.add_shape(static_cast<int64>(shape[0]));
    const_node_info.add_shape(static_cast<int64>(shape[1]));
    const_node_info.add_shape(static_cast<int64>(shape[2]));
    const_node_info.add_shape(static_cast<int64>(shape[3]));
  }
  return node_name_to_id_cache_map_[shape_name];
}

// Registers (or reuses) a const node for an arbitrary tensor of rank <= 4.
// The shape is left-padded with 1s to SHAPE_ARRAY_SIZE dims.  Keyed by
// CONST_TENSOR_PREFIX + suffix; returns the node id.
int GraphTransferer::RegisterConstTensor(const Tensor& tensor,
                                         const string& suffix) {
  VLOG(1) << "Cache const tensor.";
  const int dims = tensor.shape().dims();
  CHECK(dims <= 4);
  const string node_name = strings::StrCat(CONST_TENSOR_PREFIX, "_", suffix);
  if (node_name_to_id_cache_map_.count(node_name) <= 0) {
    node_name_cache_list_.emplace_back(nullptr);
    const int id = node_name_cache_list_.size() - 1;
    node_name_to_id_cache_map_.emplace(node_name, id);
    GraphTransferInfo::ConstNodeInfo& const_node_info =
        *graph_transfer_info_.add_const_node_info();
    const_node_info.set_name(node_name);
    const_node_info.set_node_id(id);
    CHECK_EQ(4, SHAPE_ARRAY_SIZE);
    // Left-pad with 1s so the recorded shape always has 4 dims.
    for (int i = 0; i < SHAPE_ARRAY_SIZE; ++i) {
      if (i < SHAPE_ARRAY_SIZE - dims) {
        const_node_info.add_shape(1);
      } else {
        const_node_info.add_shape(
            tensor.shape().dim_size(i - (SHAPE_ARRAY_SIZE - dims)));
      }
    }
    const_node_info.set_dtype(tensor.dtype());
    const_node_info.set_data(tensor.tensor_data().data(),
                             tensor.tensor_data().size());
  }
  return node_name_to_id_cache_map_[node_name];
}

// Registers (or reuses) a scalar const node of dtype `dt` holding `val`,
// shaped [1,1,1,1].  Keyed by destination node id and input count so each
// consumer gets a distinct entry.  Returns the node id.
int GraphTransferer::RegisterConstScalar(const DataType dt, const int val,
                                         const int dst_id,
                                         const int dst_input_count) {
  VLOG(1) << "Cache const.";
  const string val_name =
      CONST_VAL_PREFIX + ToString(dst_id) + '_' + ToString(dst_input_count);
  if (node_name_to_id_cache_map_.count(val_name) <= 0) {
    node_name_cache_list_.emplace_back(nullptr);
    const int id = node_name_cache_list_.size() - 1;
    node_name_to_id_cache_map_.emplace(val_name, id);
    GraphTransferInfo::ConstNodeInfo&
        const_node_info =
            *graph_transfer_info_.add_const_node_info();
    const_node_info.set_name(val_name);
    const_node_info.set_node_id(id);
    // TODO(satok): Do not assume rank is 4 here.
    const_node_info.add_shape(static_cast<int64>(1));
    const_node_info.add_shape(static_cast<int64>(1));
    const_node_info.add_shape(static_cast<int64>(1));
    const_node_info.add_shape(static_cast<int64>(1));
    // NOTE(review): copies DataTypeSize(dt) bytes from an int — only safe
    // while DataTypeSize(dt) <= sizeof(int); confirm callers only pass
    // DT_INT32 (RegisterNodeWithRank does).
    const_node_info.set_data(&val, DataTypeSize(dt));
  }
  return node_name_to_id_cache_map_[val_name];
}

// True iff the node carries both "padding" and "strides" attributes
// (e.g. Conv2D, MaxPool-style ops).
bool GraphTransferer::HasPaddingAndStrides(const Node& node) {
  auto attrs = node.attrs();
  return attrs.Find(PADDING_ATTR_NAME) != nullptr &&
         attrs.Find(STRIDES_ATTR_NAME) != nullptr;
}

// True for op types that need an extra rank scalar appended to their inputs
// (Transpose and ExpandDims).
bool GraphTransferer::NeedsToAddRank(const Node& node) {
  const StringPiece op_type(node.type_string());
  if (op_type == "Transpose" || op_type == "ExpandDims") {
    return true;
  }
  return false;
}

// True iff the node is a Pad op.
bool GraphTransferer::IsPadNode(const Node& node) {
  const StringPiece op_type(node.type_string());
  if (op_type == "Pad") {
    return true;
  }
  return false;
}

// True iff the node is a Reshape whose (known) output shape is [1,1,1,N],
// i.e. the reshape is effectively a flatten.  Falls back to the node's
// recorded output-shape attribute when shape inference did not resolve it.
bool GraphTransferer::IsNodeFlattenReshape(const Node& node,
                                           const ShapeRefiner& shape_refiner) {
  // Check if node is reshape op
  if (node.type_string() != RESHAPE_NODE_TYPE_STRING) {
    return false;
  }

  shape_inference::InferenceContext* context = shape_refiner.GetContext(&node);
  // Check if output count is valid
  if (context->num_outputs() != 1) {
    return false;
  }

  shape_inference::ShapeHandle shape_handle = context->output(0);
  std::array<int64, SHAPE_ARRAY_SIZE> shape_array;
  const shape_inference::DimensionHandle dim_handle =
      context->NumElements(shape_handle);

  // Obtain shape of output of node
  if (context->ValueKnown(dim_handle)) {
    shape_array = BuildShapeArray(shape_handle, context);
  } else {
    std::vector<TensorShape> shapes;
    TF_CHECK_OK(RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
        node.attrs(), nullptr, &shapes));

    // Number of outputs should be 1 for reshape node.
    CHECK_EQ(1, shapes.size());
    shape_array = ToTensorShapeArray(shapes.at(0));
  }

  // check if reshape op just does flatten
  if (shape_array[0] == 1 && shape_array[1] == 1 && shape_array[2] == 1) {
    return true;
  } else {
    return false;
  }
}

// Registers an op carrying "padding"/"strides" attrs (and optionally
// "ksize"): the strides (and kernel sizes) are materialized as const shape
// nodes appended as extra inputs, and the padding enum value is recorded.
void GraphTransferer::RegisterNodeWithPaddingAndStrides(
    const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    const ShapeRefiner& shape_refiner, const Node& node) {
  CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
  const int id = node_name_to_id_cache_map_[node.name()];
  shape_inference::InferenceContext* context = shape_refiner.GetContext(&node);
  CHECK(node.attrs().Find(PADDING_ATTR_NAME));
  // TODO(satok): Use context->GetAttr(...) instead?
  Padding padding;
  TF_CHECK_OK(context->GetAttr(PADDING_ATTR_NAME, &padding));
  CHECK(node.attrs().Find(STRIDES_ATTR_NAME));
  std::vector<int32> strides;
  TF_CHECK_OK(context->GetAttr(STRIDES_ATTR_NAME, &strides));
  const int stride_id = RegisterConstantShape(strides);
  std::vector<int> extra_inputs{stride_id};
  if (node.attrs().Find(KSIZE_ATTR_NAME)) {
    std::vector<int32> kernel_sizes;
    TF_CHECK_OK(context->GetAttr(KSIZE_ATTR_NAME, &kernel_sizes));
    const int ksize_id = RegisterConstantShape(kernel_sizes);
    // ksize goes before strides in the extra-input list.
    extra_inputs.insert(extra_inputs.begin(), ksize_id);
  }
  // TODO(satok): Set correct data type if it's given.
  const int op_type_id = ops_definitions.GetOpIdFor(node.type_string(), {});
  CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount())
      << "Op " << node.type_string() << " not found in map(id = " << op_type_id
      << ")";
  // Safety check of padding id
  CHECK(padding == Padding::VALID ?
        1 : 2);
  AppendNodeParamsWithIoParams(
      shape_refiner, node, node.name(), id, node.type_string(), op_type_id,
      static_cast<int>(padding), node.num_inputs(), extra_inputs,
      node.num_outputs(), true /* append_input */, true /* append_output */);
}

// Registers a Transpose/ExpandDims node: the rank of input 0 is appended as
// a const scalar extra input, and keep_dims (when present) is encoded into
// the padding id field (SAME = keep, VALID = drop).
void GraphTransferer::RegisterNodeWithRank(
    const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    const ShapeRefiner& shape_refiner, const Node& node) {
  CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
  const int id = node_name_to_id_cache_map_[node.name()];
  shape_inference::InferenceContext* context = shape_refiner.GetContext(&node);
  const Node* input0_node;
  TF_CHECK_OK(node.input_node(0, &input0_node));
  CHECK_NOTNULL(input0_node);
  std::vector<TensorShape> shapes;
  // NOTE(review): `status` is never checked before shapes is used; the
  // CHECK_EQ below would fire on failure anyway, but the error is dropped.
  Status status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
      input0_node->attrs(), nullptr, &shapes);
  CHECK_EQ(1, shapes.size()) << "Output size should be 1.";
  const int const_val_id =
      RegisterConstScalar(DT_INT32, shapes.at(0).dims(), id, node.num_inputs());
  std::vector<int> extra_inputs{const_val_id};
  // TODO(satok): Set correct data type if it's given.
  const int op_type_id = ops_definitions.GetOpIdFor(node.type_string(), {});
  CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount())
      << "Op " << node.type_string() << " not found in map(id = " << op_type_id
      << ")";
  bool keep_dims = false;
  int padding_id = PADDING_NA_ID;
  if (context->GetAttr(KEEP_DIMS_ATTR_NAME, &keep_dims).ok()) {
    padding_id = keep_dims ? Padding::SAME : Padding::VALID;
  }

  AppendNodeParamsWithIoParams(
      shape_refiner, node, node.name(), id, node.type_string(), op_type_id,
      padding_id, node.num_inputs(), extra_inputs, node.num_outputs(),
      true /* append_input */, true /* append_output */);
}

// Registers a Pad node.  Input 1 (the paddings tensor, [width, 2] int32)
// must be constant; if its width is less than 4 it is zero-padded on the
// leading rows into a new [4, 2] const tensor so the SOC always sees rank-4
// paddings.  Input params are written explicitly (append_input = false).
void GraphTransferer::RegisterPadNode(
    const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    const ShapeRefiner& shape_refiner, const Node& node) {
  static constexpr int PAD_WIDTH = 4;
  static constexpr int PAD_HEIGHT = 2;
  VLOG(1) << "Register generic node: " << node.name();
  CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
  const int id = node_name_to_id_cache_map_[node.name()];

  // TODO(satok): Set correct data type if it's given.
  const int op_type_id = ops_definitions.GetOpIdFor(node.type_string(), {});
  CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount());

  CHECK_EQ(2, node.num_inputs());

  GraphTransferInfo::NodeInputInfo& node_input_info =
      *graph_transfer_info_.add_node_input_info();
  node_input_info.set_node_id(id);

  // Input 0 (the data tensor) is passed through unchanged.
  AddNodeInputByInputIndex(node, 0, &node_input_info);

  const Edge* edge = nullptr;
  TF_CHECK_OK(node.input_edge(1, &edge));
  const Node* input_node = edge->src();
  CHECK_NOTNULL(input_node);
  CHECK(input_node->IsConstant());

  const TensorProto* tensor_proto = nullptr;
  TF_CHECK_OK(GetNodeAttr(input_node->attrs(), "value", &tensor_proto));
  CHECK_NOTNULL(tensor_proto);
  Tensor const_tensor;
  TF_CHECK_OK(MakeTensorFromProto(*tensor_proto, &const_tensor));
  CHECK_EQ(2, const_tensor.shape().dims());
  CHECK_EQ(PAD_HEIGHT, const_tensor.shape().dim_size(1));
  if (const_tensor.shape().dim_size(0) == PAD_WIDTH) {
    // Already rank-4 paddings — use the const node as-is.
    AddNodeInputByInputIndex(node, 1, &node_input_info);
  } else if (const_tensor.shape().dim_size(0) < PAD_WIDTH) {
    const int width = const_tensor.shape().dim_size(0);
    // NOTE(review): this re-fetches the same "value" attr and shadows both
    // `const_tensor` and (below) `id` from the enclosing scope — redundant
    // work and easy to misread; the shadowed outer `id` is still used for
    // AppendNodeParamsWithIoParams at the end.
    const TensorProto* proto = nullptr;
    TF_CHECK_OK(GetNodeAttr(input_node->attrs(), "value", &proto));
    Tensor const_tensor;
    TF_CHECK_OK(MakeTensorFromProto(*proto, &const_tensor));
    CHECK_EQ(DT_INT32, const_tensor.dtype());
    // reshape tensor input to be rank 4.
    // TODO(satok): Never assume rank is 4.
    Tensor new_const_tensor(const_tensor.dtype(), TensorShape{4, 2});
    for (int i = 0; i < PAD_HEIGHT; ++i) {
      for (int j = 0; j < PAD_WIDTH; ++j) {
        if (j < PAD_WIDTH - width) {
          // Leading (outermost) dimensions get zero padding.
          new_const_tensor.matrix<int32>()(j, i) = 0;
        } else {
          new_const_tensor.matrix<int32>()(j, i) =
              const_tensor.matrix<int32>()(j - (PAD_WIDTH - width), i);
        }
      }
    }

    const int id = RegisterConstTensor(
        new_const_tensor,
        strings::StrCat(input_node->name(), "_", node.name(), "_1"));

    GraphTransferInfo::NodeInput& node_input =
        *node_input_info.add_node_input();
    node_input.set_node_id(id);
    node_input.set_output_port(0);
  } else {
    // Paddings wider than 4 are unsupported.
    LOG(FATAL);
  }

  AppendNodeParamsWithIoParams(
      shape_refiner, node, node.name(), id, node.type_string(), op_type_id,
      PADDING_NA_ID, node.num_inputs(), {}, node.num_outputs(),
      false /* append_input */, true /* append_output */);
}

// Registers the aggregated input node under the dedicated "INPUT" SOC op id.
void GraphTransferer::RegisterInputNode(
    const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    const ShapeRefiner& shape_refiner, const Node& node) {
  const string op_type = node.type_string();
  VLOG(1) << "Register input node: " << node.name() << ", " << op_type;
  CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
  const int id = node_name_to_id_cache_map_[node.name()];
  // TODO(satok): Set correct data type if it's given.
  const int op_type_id = ops_definitions.GetOpIdFor("INPUT", {});
  CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount())
      << "Op" << node.name() << ", " << op_type << " is not supported,"
      << op_type_id;
  AppendNodeParamsWithIoParams(
      shape_refiner, node, node.name(), id, node.type_string(), op_type_id,
      PADDING_NA_ID, node.num_inputs(), {}, node.num_outputs(),
      true /* append_input */, true /* append_output */);
}

// Registers a Reshape that flattens (see IsNodeFlattenReshape) under the
// dedicated "FLATTEN" SOC op id.
void GraphTransferer::RegisterFlattenNode(
    const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    const ShapeRefiner& shape_refiner, const Node& node) {
  VLOG(1) << "Register flatten node: " << node.name();
  CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
  const int id = node_name_to_id_cache_map_[node.name()];
  // TODO(satok): Remove dependency to specific type
  const string op_type = "FLATTEN";
  // TODO(satok): Set correct data type if it's given.
  const int op_type_id = ops_definitions.GetOpIdFor(op_type, {});
  CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount());

  AppendNodeParamsWithIoParams(
      shape_refiner, node, node.name(), id, node.type_string(), op_type_id,
      PADDING_NA_ID, node.num_inputs(), {}, node.num_outputs(),
      true /* append_input */, true /* append_output */);
}

// Fallback registration for any op whose type has a direct SOC op id and
// needs no special handling (no padding/strides/rank/flatten treatment).
void GraphTransferer::RegisterGenericNode(
    const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    const ShapeRefiner& shape_refiner, const Node& node) {
  VLOG(1) << "Register generic node: " << node.name();
  CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1);
  const int id = node_name_to_id_cache_map_[node.name()];
  // TODO(satok): Set correct data type if it's given.
  const int op_type_id = ops_definitions.GetOpIdFor(node.type_string(), {});
  CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount());

  AppendNodeParamsWithIoParams(
      shape_refiner, node, node.name(), id, node.type_string(), op_type_id,
      PADDING_NA_ID, node.num_inputs(), {}, node.num_outputs(),
      true /* append_input */, true /* append_output */);
}

// Guarded wrapper around RegisterNode: optionally limits registration to
// constants, and asserts that all of the node's inputs already have ids.
// TODO(satok): Remove this function.
// TODO(satok): Remove only_register_const_node.
Status GraphTransferer::RegisterNodeIfAllInputsAreCached(
    const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    const ShapeRefiner& shape_refiner, const Node& node,
    const bool only_register_const_node,
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    const std::vector<string>& output_node_names) {
  if (only_register_const_node && !node.IsConstant()) {
    return Status();
  }
  CHECK(AreAllInputsCached(node));
  return RegisterNode(ops_definitions, shape_refiner, node,
                      input_node_info_list, output_node_names);
}

// Records a node's core transfer parameters (name, id, type, SOC op id,
// padding id, input/output counts).
// CAVEAT: Append inputs and outputs params accordingly
void GraphTransferer::AppendNodeParams(const string& name, const int id,
                                       const string& type, const int type_id,
                                       const int padding, const int inputs_size,
                                       const std::vector<int>& extra_inputs,
                                       const int outputs_size) {
  GraphTransferInfo::NodeInfo& node_info =
      *graph_transfer_info_.add_node_info();
  node_info.set_name(name);
  node_info.set_node_id(id);
  node_info.set_type_name(type);
  node_info.set_soc_op_id(type_id);
  node_info.set_padding_id(padding);
  // Input count covers both real graph inputs and synthesized extras.
  node_info.set_input_count(inputs_size +
                            static_cast<int>(extra_inputs.size()));
  node_info.set_output_count(static_cast<int>(outputs_size));
}

// Appends the (source node id, source output port) pair for the node's
// idx-th input edge to node_input_info.  The source must already be cached.
void GraphTransferer::AddNodeInputByInputIndex(
    const Node& node, const int idx,
    GraphTransferInfo::NodeInputInfo* node_input_info) {
  const Edge* edge = nullptr;
TF_CHECK_OK(node.input_edge(idx, &edge));
  const Node* input_node = edge->src();
  CHECK_NOTNULL(input_node);
  const int port = edge->src_output();

  // The source node must already have a cached id; inputs are registered
  // before the nodes that consume them.
  const std::string& op_name = input_node->name();
  CHECK_GT(node_name_to_id_cache_map_.count(op_name), 0) << op_name;
  const int src_id = node_name_to_id_cache_map_[op_name];
  GraphTransferInfo::NodeInput& node_input = *node_input_info->add_node_input();
  node_input.set_node_id(src_id);
  node_input.set_output_port(port);
}

// Appends a NodeInputInfo entry for node `id`: one NodeInput per real graph
// input edge, followed by the synthesized extra inputs (always on port 0).
void GraphTransferer::AppendNodeInputParams(
    const int id, const Node& node, const std::vector<int>& extra_inputs) {
  VLOG(1) << "Append input params: " << node.name() << ", " << node.num_inputs()
          << ", " << extra_inputs.size();
  GraphTransferInfo::NodeInputInfo& node_input_info =
      *graph_transfer_info_.add_node_input_info();
  node_input_info.set_node_id(id);
  for (int i = 0; i < node.num_inputs(); ++i) {
    AddNodeInputByInputIndex(node, i, &node_input_info);
  }
  for (const int extra_input : extra_inputs) {
    // Extra inputs reference const nodes registered elsewhere; they expose a
    // single output, hence port 0.
    GraphTransferInfo::NodeInput& node_input =
        *node_input_info.add_node_input();
    node_input.set_node_id(extra_input);
    node_input.set_output_port(0);
  }
}

// Appends a NodeOutputInfo entry for node `id` recording the maximum byte
// size of each output. Sizes come from shape inference when the element
// count is known, otherwise from the shape/type attribute attached to the
// node (which must then be present — TF_CHECK_OK below).
void GraphTransferer::AppendNodeOutputParams(const ShapeRefiner& shape_refiner,
                                             const int id, const Node& node) {
  VLOG(1) << "Append output params: " << node.name() << ", "
          << node.num_outputs();
  GraphTransferInfo::NodeOutputInfo& node_output_info =
      *graph_transfer_info_.add_node_output_info();
  node_output_info.set_node_id(id);

  std::vector<DataType> data_types;
  std::vector<TensorShape> shapes;
  Status status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
      node.attrs(), &data_types, &shapes);

  for (int i = 0; i < node.num_outputs(); ++i) {
    int data_size = -1;
    const int output_index = i;
    const DataType dt = node.output_type(output_index);
    const size_t max_bytes_per_data = DataTypeSize(dt);

    shape_inference::InferenceContext* context =
        shape_refiner.GetContext(&node);

    if (context != nullptr && context->ValueKnown(context->NumElements(
                                  context->output(output_index)))) {
      // Shape inference knows the exact element count: use it, and
      // cross-check against the attribute when one is available.
      const shape_inference::DimensionHandle num_elements_dim =
          context->NumElements(context->output(output_index));
      const int64 num_output_elements = context->Value(num_elements_dim);
      data_size = max_bytes_per_data * num_output_elements;
      if (status.ok()) {
        // NOTE(review): TF_CHECK_OK is redundant here (guarded by
        // status.ok()); kept as-is.
        TF_CHECK_OK(status);
        CHECK_EQ(shapes.at(i).num_elements(), num_output_elements);
      }
    } else {
      // No inferred count: the attribute is mandatory.
      TF_CHECK_OK(status);
      // Use attribute attached to node
      data_size = max_bytes_per_data * shapes.at(i).num_elements();
    }
    CHECK_GE(data_size, 0);
    node_output_info.add_max_byte_size(data_size);
  }
}

// Convenience wrapper: optionally appends input params and output params for
// `node`, then always appends the NodeInfo entry itself.
void GraphTransferer::AppendNodeParamsWithIoParams(
    const ShapeRefiner& shape_refiner, const Node& node, const string& name,
    const int id, const string& type, const int type_id, const int padding,
    const int inputs_size, const std::vector<int>& extra_inputs,
    const int outputs_size, const bool append_input_params,
    const bool append_output_params) {
  VLOG(1) << "Append node with io params: " << node.name();
  if (append_input_params) {
    AppendNodeInputParams(id, node, extra_inputs);
  }
  if (append_output_params) {
    AppendNodeOutputParams(shape_refiner, id, node);
  }
  AppendNodeParams(name, id, type, type_id, padding, inputs_size, extra_inputs,
                   outputs_size);
}

// Converts an inferred shape (rank 0..4) into a fixed 4-element array,
// left-padding with 1s so the known dimensions occupy the trailing slots.
// Rank > 4 is fatal.
/* static */ std::array<int64, GraphTransferer::SHAPE_ARRAY_SIZE>
GraphTransferer::BuildShapeArray(
    const shape_inference::ShapeHandle& shape_handle,
    shape_inference::InferenceContext* context) {
  switch (context->Rank(shape_handle)) {
    case 0:
      return std::array<int64, SHAPE_ARRAY_SIZE>{{1, 1, 1, 1}};
    case 1:
      return std::array<int64, SHAPE_ARRAY_SIZE>{
          {1, 1, 1, context->Value(context->Dim(shape_handle, 0))}};
    case 2:
      return std::array<int64, SHAPE_ARRAY_SIZE>{
          {1, 1, context->Value(context->Dim(shape_handle, 0)),
           context->Value(context->Dim(shape_handle, 1))}};
    case 3:
      return std::array<int64, SHAPE_ARRAY_SIZE>{
          {1, context->Value(context->Dim(shape_handle, 0)),
           context->Value(context->Dim(shape_handle, 1)),
           context->Value(context->Dim(shape_handle, 2))}};
    case 4:
      return std::array<int64, SHAPE_ARRAY_SIZE>{
          {context->Value(context->Dim(shape_handle, 0)),
           context->Value(context->Dim(shape_handle, 1)),
           context->Value(context->Dim(shape_handle, 2)),
           context->Value(context->Dim(shape_handle, 3))}};
    default:
      // TODO(satok): Support more ranks?
LOG(FATAL);
      return std::array<int64, SHAPE_ARRAY_SIZE>();
  }
}

// Same padding scheme as BuildShapeArray but for a concrete TensorShape:
// rank 0..4 maps into a 4-element array left-padded with 1s; higher ranks
// are fatal.
/* static */ std::array<int64, GraphTransferer::SHAPE_ARRAY_SIZE>
GraphTransferer::ToTensorShapeArray(const TensorShape& shape) {
  switch (shape.dims()) {
    case 0:
      return std::array<int64, SHAPE_ARRAY_SIZE>{{1, 1, 1, 1}};
    case 1:
      return std::array<int64, SHAPE_ARRAY_SIZE>{{1, 1, 1, shape.dim_size(0)}};
    case 2:
      return std::array<int64, SHAPE_ARRAY_SIZE>{
          {1, 1, shape.dim_size(0), shape.dim_size(1)}};
    case 3:
      return std::array<int64, SHAPE_ARRAY_SIZE>{
          {1, shape.dim_size(0), shape.dim_size(1), shape.dim_size(2)}};
    case 4:
      return std::array<int64, SHAPE_ARRAY_SIZE>{
          {shape.dim_size(0), shape.dim_size(1), shape.dim_size(2),
           shape.dim_size(3)}};
    default:
      // TODO(satok): Support more ranks?
      LOG(FATAL);
      return std::array<int64, SHAPE_ARRAY_SIZE>();
  }
}

// Maps a padding id to a human-readable string for the dump functions.
// Value 0 is treated as "not applicable"; any other non-VALID/SAME value is
// fatal.
/* static */ string GraphTransferer::ToPaddingDebugString(const int padding) {
  switch (padding) {
    case 0:
      return "NN_PAD_NA";
    case Padding::VALID:
      return "NN_PAD_VALID";
    case Padding::SAME:
      return "NN_PAD_SAME";
    default:
      LOG(FATAL);
      return "";
  }
}

// Comparator holding a node-id -> transitive-dependency-set map (built by
// FillDependencyRec) used to order NodeInfo entries for execution.
GraphTransferer::TransferParamsComparator::TransferParamsComparator(
    const std::unordered_map<int, std::unordered_set<int>>& dep_map)
    : dependency_map_(dep_map) {}

// Orders obj0 before obj1 when obj1 depends on obj0; falls back to node-id
// order for independent nodes. Mutual dependency is a CHECK failure.
// NOTE(review): this relies on dependency_map_ being transitively closed
// (see FillDependencyRec) for the ordering to be consistent — confirm all
// callers close the map first.
bool GraphTransferer::TransferParamsComparator::operator()(
    const GraphTransferInfo::NodeInfo& obj0,
    const GraphTransferInfo::NodeInfo& obj1) {
  const int node_id0 = obj0.node_id();
  const int node_id1 = obj1.node_id();
  bool obj0_uses_obj1 = false;
  if (dependency_map_.count(node_id0) > 0) {
    obj0_uses_obj1 = dependency_map_.at(node_id0).count(node_id1) > 0;
  }
  bool obj1_uses_obj0 = false;
  if (dependency_map_.count(node_id1) > 0) {
    obj1_uses_obj0 = dependency_map_.at(node_id1).count(node_id0) > 0;
  }
  // A dependency cycle between two nodes is a bug upstream.
  CHECK(!obj0_uses_obj1 || !obj1_uses_obj0);
  if (obj0_uses_obj1) {
    return false;
  } else if (obj1_uses_obj0) {
    return true;
  }
  // If there is no dependency between two nodes, it expects that
  // the execution order follows node id order.
  return node_id0 < node_id1;
}

// Recursively expands dep_map[node_id] to its transitive closure.
// `completed` memoizes node ids whose entry is already fully expanded so
// each node is processed at most once.
/* static */ void GraphTransferer::FillDependencyRec(
    const int node_id,
    std::unordered_map<int, std::unordered_set<int>>& dep_map,
    std::unordered_set<int>& completed) {
  // Nothing to do for leaf nodes or nodes already expanded.
  if (dep_map.count(node_id) == 0 || dep_map.at(node_id).empty() ||
      completed.count(node_id) == 1) {
    return;
  }
  CHECK_EQ(dep_map.count(node_id), 1);

  // Complete children's dependency map
  for (int child_node_id : dep_map.at(node_id)) {
    CHECK(child_node_id != node_id);
    if (completed.count(child_node_id) != 0) {
      continue;
    }
    FillDependencyRec(child_node_id, dep_map, completed);
  }

  // Find additional depending ids
  // (collected into a separate vector so dep_map.at(node_id) is not mutated
  // while being iterated)
  std::vector<int> depending_ids;
  for (int child_node_id : dep_map.at(node_id)) {
    if (dep_map.count(child_node_id) == 0) {
      continue;
    }
    for (int depending_id : dep_map.at(child_node_id)) {
      depending_ids.emplace_back(depending_id);
    }
  }

  // Insert additional depending ids
  for (int depending_id : depending_ids) {
    if (dep_map.at(node_id).count(depending_id) == 0) {
      dep_map.at(node_id).emplace(depending_id);
    }
  }

  // DP: Record completed node id
  completed.emplace(node_id);
}

// Parses a TensorProto into a Tensor using the CPU allocator; returns
// InvalidArgument when the dtype is out of range or parsing fails.
/* static */ Status GraphTransferer::MakeTensorFromProto(
    const TensorProto& tensor_proto, Tensor* tensor) {
  if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) {
    Tensor parsed(tensor_proto.dtype());
    if (parsed.FromProto(cpu_allocator(), tensor_proto)) {
      *tensor = parsed;
      return Status::OK();
    }
  }
  return errors::InvalidArgument("Cannot parse tensor from proto: ",
                                 tensor_proto.DebugString());
}

// Drops the name/id caches so a subsequent load starts from a clean state.
void GraphTransferer::ClearCache() {
  node_name_cache_list_.clear();
  node_name_to_id_cache_map_.clear();
}

// Logs a human-readable dump of every registered const node, op node, and
// their input/output parameter lists. Debug aid only; no side effects on
// graph_transfer_info_.
void GraphTransferer::DumpNodeTransferParams() const {
  LOG(INFO) << "*** Const Nodes ***";
  for (const GraphTransferInfo::ConstNodeInfo& params :
       graph_transfer_info_.const_node_info()) {
    // TODO(satok): Stop assuming shape size is 4.
    CHECK_EQ(params.shape_size(), 4);
    LOG(INFO) << "[ " << params.node_id() << " \"" << params.name()
              << "\" (Const)";
    LOG(INFO) << " shape: " << params.shape(0) << params.shape(1)
              << params.shape(2) << params.shape(3);
    LOG(INFO) << " data_name: "
              << (params.data().length() <= 0
                      ? ""
                      : DATA_NODE_PREFIX + ToString(params.node_id()));
    LOG(INFO) << " data_size: " << params.data().length() << " bytes"
              << " ]";
  }
  LOG(INFO) << "******\n";
  LOG(INFO) << "*** Op Nodes ***";
  for (const GraphTransferInfo::NodeInfo& params :
       graph_transfer_info_.node_info()) {
    LOG(INFO) << "[ " << params.node_id() << " \"" << params.name();
    LOG(INFO) << " type: " << params.type_name();
    LOG(INFO) << " padding: " << ToPaddingDebugString(params.padding_id());
    LOG(INFO) << " inputs: " << INPUTS_NODE_PREFIX + ToString(params.node_id())
              << ", size = " << params.input_count();
    LOG(INFO) << " outputs: "
              << (params.output_count() <= 0
                      ? NULL_OUTPUT_NAME
                      : (OUTPUTS_NODE_PREFIX + ToString(params.node_id())))
              << ", size = " << params.output_count() << " ]";
  }
  LOG(INFO) << "******\n";
  LOG(INFO) << "*** Node input params ***";
  for (const GraphTransferInfo::NodeInputInfo& params :
       graph_transfer_info_.node_input_info()) {
    LOG(INFO) << "[ " << params.node_id() << " ]";
    for (const GraphTransferInfo::NodeInput& node_input : params.node_input()) {
      LOG(INFO) << " src node id = " << node_input.node_id()
                << ", output port = " << node_input.output_port();
    }
  }
  LOG(INFO) << "******\n";
  LOG(INFO) << "*** Node output params ***";
  for (const GraphTransferInfo::NodeOutputInfo& params :
       graph_transfer_info_.node_output_info()) {
    LOG(INFO) << "[ " << params.node_id() << " ]";
    for (const int max_size : params.max_byte_size()) {
      LOG(INFO) << " max_size = " << max_size;
    }
  }
  LOG(INFO) << "******\n";
}

// Logs a compact, one-line-per-entry encoding of the same data as
// DumpNodeTransferParams (ids in hex), intended for diffing two transfer
// results during verification. Debug aid only.
void GraphTransferer::DumpVerificationStringOfNodeTransferParams() const {
  for (const GraphTransferInfo::ConstNodeInfo& params :
       graph_transfer_info_.const_node_info()) {
    std::stringstream sstream;
    // TODO(satok): Stop assuming shape size is 4.
    CHECK_EQ(params.shape_size(), 4);
    sstream << "---(CONST) [" << std::hex << params.node_id() << std::dec << ","
            << params.shape(0) << "," << params.shape(1) << ","
            << params.shape(2) << "," << params.shape(3) << ","
            << (params.data().length() <= 0
                    ? ""
                    : DATA_NODE_PREFIX + ToString(params.node_id()))
            << "," << params.data().length() << "," << params.name() << "]";
    LOG(INFO) << sstream.str();
  }
  LOG(INFO) << "Const node count = "
            << graph_transfer_info_.const_node_info_size();
  for (const GraphTransferInfo::NodeInfo& params :
       graph_transfer_info_.node_info()) {
    std::stringstream sstream;
    sstream << "---(OP) [" << params.name().c_str() << "," << std::hex
            << params.node_id() << std::dec << "," << params.soc_op_id() << ","
            << ToPaddingDebugString(params.padding_id()) << ","
            << INPUTS_NODE_PREFIX + ToString(params.node_id()) << ","
            << params.input_count() << ","
            << (params.output_count() <= 0
                    ? NULL_OUTPUT_NAME
                    : (OUTPUTS_NODE_PREFIX + ToString(params.node_id())))
            << "," << params.output_count() << "," << params.type_name() << "]";
    LOG(INFO) << sstream.str();
  }
  LOG(INFO) << "Op node count = " << graph_transfer_info_.node_info_size();
  for (const GraphTransferInfo::NodeInputInfo& params :
       graph_transfer_info_.node_input_info()) {
    std::stringstream sstream;
    sstream << "---(INPUT) [" << std::hex << params.node_id() << std::dec;
    for (const GraphTransferInfo::NodeInput& node_input : params.node_input()) {
      sstream << "," << std::hex << node_input.node_id() << std::dec << ","
              << node_input.output_port();
    }
    sstream << "]";
    LOG(INFO) << sstream.str();
  }
  LOG(INFO) << "Input params count = "
            << graph_transfer_info_.node_input_info_size();
  for (const GraphTransferInfo::NodeOutputInfo& params :
       graph_transfer_info_.node_output_info()) {
    std::stringstream sstream;
    sstream << "---(OUTPUT) [" << std::hex << params.node_id() << std::dec;
    for (const int max_size : params.max_byte_size()) {
      sstream << "," << max_size;
    }
    sstream << "]";
    LOG(INFO) << sstream.str();
  }
  LOG(INFO) << "Output params count = "
            << graph_transfer_info_.node_output_info_size();
}

}  // namespace tensorflow