1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/grappler/clusters/virtual_cluster.h" 17 #include "tensorflow/core/framework/cost_graph.pb.h" 18 #include "tensorflow/core/framework/tensor_shape.pb.h" 19 #include "tensorflow/core/framework/types.h" 20 #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" 21 #include "tensorflow/core/grappler/costs/virtual_scheduler.h" 22 23 namespace tensorflow { 24 namespace grappler { 25 26 VirtualCluster::VirtualCluster( 27 const std::unordered_map<string, DeviceProperties>& devices) 28 : Cluster(0), 29 node_estimator_(new OpLevelCostEstimator()), 30 node_manager_(new FirstReadyManager()) { 31 devices_ = devices; 32 } 33 34 VirtualCluster::VirtualCluster( 35 const std::unordered_map<string, DeviceProperties>& devices, 36 OpLevelCostEstimator* node_estimator, ReadyNodeManager* node_manager) 37 : Cluster(0), node_estimator_(node_estimator), node_manager_(node_manager) { 38 devices_ = devices; 39 } 40 VirtualCluster::~VirtualCluster() {} 41 42 Status VirtualCluster::Provision() { return Status::OK(); } 43 44 Status VirtualCluster::Initialize(const GrapplerItem& item) { 45 return Status::OK(); 46 } 47 48 Status VirtualCluster::Run(const GraphDef& graph, 49 const std::vector<std::pair<string, Tensor>>& feed, 50 const std::vector<string>& fetch, 51 RunMetadata* metadata) { 52 // Initialize a virtual scheduler to process the graph. Make sure to use 53 // static shape inference to prevent the schedulrer from calling the Run 54 // method on the cluster, and create an infinite loop. 55 GrapplerItem item; 56 item.graph = graph; 57 item.feed = feed; 58 item.fetch = fetch; 59 VirtualScheduler scheduler(&item, true, this, node_manager_.get()); 60 TF_RETURN_IF_ERROR(scheduler.Init()); 61 62 if (metadata) { 63 metadata->clear_step_stats(); 64 metadata->clear_cost_graph(); 65 metadata->clear_partition_graphs(); 66 } 67 68 Costs node_costs; 69 do { 70 OpContext op_context = scheduler.GetCurrNode(); 71 node_costs = node_estimator_->PredictCosts(op_context); 72 if (metadata) { 73 CostGraphDef::Node* cost_node = 74 metadata->mutable_cost_graph()->add_node(); 75 const string& op_name = op_context.name; 76 cost_node->set_name(op_name); 77 cost_node->set_device(op_context.device_name); 78 cost_node->set_compute_cost( 79 node_costs.execution_time.asMicroSeconds().count()); 80 cost_node->set_compute_time( 81 node_costs.compute_time.asMicroSeconds().count()); 82 cost_node->set_memory_time( 83 node_costs.memory_time.asMicroSeconds().count()); 84 for (const auto& output : op_context.op_info.outputs()) { 85 auto output_info = cost_node->add_output_info(); 86 output_info->set_dtype(output.dtype()); 87 *output_info->mutable_shape() = output.shape(); 88 89 int64 size = DataTypeSize(output.dtype()); 90 for (const auto& dim : output.shape().dim()) { 91 size *= std::max<int64>(1, dim.size()); 92 } 93 output_info->set_size(size); 94 } 95 } 96 } while (scheduler.MarkCurrNodeExecuted(node_costs)); 97 98 if (metadata) { 99 scheduler.Summary(metadata); 100 } 101 102 const std::unordered_map<string, DeviceProperties>& device = GetDevices(); 103 std::unordered_map<string, int64> peak_mem_usage = 104 scheduler.GetPeakMemoryUsage(); 105 for (const auto& mem_usage : peak_mem_usage) { 106 const string& device_name = mem_usage.first; 107 auto it = device.find(device_name); 108 if (it == device.end()) { 109 // It's probably the fake send/recv device. Eventually we'll need to 110 // remove this fake device to ensure proper memory accounting for 111 // multi-device settings. 112 continue; 113 } 114 const DeviceProperties& dev = it->second; 115 if (dev.memory_size() <= 0) { 116 // Available device memory unknown 117 continue; 118 } 119 int64 peak_mem = mem_usage.second; 120 if (peak_mem >= dev.memory_size()) { 121 return errors::ResourceExhausted( 122 "Graph requires ", peak_mem, " bytes of memory on device ", 123 device_name, " to run ", " but device only has ", dev.memory_size(), 124 " available."); 125 } 126 } 127 128 return Status::OK(); 129 } 130 131 } // namespace grappler 132 } // namespace tensorflow 133