Home | History | Annotate | Download | only in clusters
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
     17 #include "tensorflow/core/framework/cost_graph.pb.h"
     18 #include "tensorflow/core/framework/tensor_shape.pb.h"
     19 #include "tensorflow/core/framework/types.h"
     20 #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
     21 #include "tensorflow/core/grappler/costs/virtual_scheduler.h"
     22 
     23 namespace tensorflow {
     24 namespace grappler {
     25 
// Constructs a virtual cluster that simulates the given set of devices
// (device name -> DeviceProperties), using the default cost-estimation
// components: an OpLevelCostEstimator for per-node cost prediction and a
// FirstReadyManager to decide node execution order during scheduling.
VirtualCluster::VirtualCluster(
    const std::unordered_map<string, DeviceProperties>& devices)
    // Cluster(0): presumably a zero timeout / no real session for the base
    // class -- confirm against the Cluster declaration in cluster.h.
    : Cluster(0),
      node_estimator_(new OpLevelCostEstimator()),
      node_manager_(new FirstReadyManager()) {
  // devices_ is inherited from Cluster; record the simulated device set so
  // GetDevices() (used by Run below) reflects these properties.
  devices_ = devices;
}
     33 
// Constructs a virtual cluster with caller-supplied cost-estimation
// components. Takes ownership of `node_estimator` and `node_manager`: they
// are stored in smart-pointer members (node_manager_.get() is used in Run),
// so callers must pass heap-allocated objects and must not delete them.
VirtualCluster::VirtualCluster(
    const std::unordered_map<string, DeviceProperties>& devices,
    OpLevelCostEstimator* node_estimator, ReadyNodeManager* node_manager)
    : Cluster(0), node_estimator_(node_estimator), node_manager_(node_manager) {
  // devices_ is inherited from Cluster; record the simulated device set.
  devices_ = devices;
}
     40 VirtualCluster::~VirtualCluster() {}
     41 
     42 Status VirtualCluster::Provision() { return Status::OK(); }
     43 
     44 Status VirtualCluster::Initialize(const GrapplerItem& item) {
     45   return Status::OK();
     46 }
     47 
     48 Status VirtualCluster::Run(const GraphDef& graph,
     49                            const std::vector<std::pair<string, Tensor>>& feed,
     50                            const std::vector<string>& fetch,
     51                            RunMetadata* metadata) {
     52   // Initialize a virtual scheduler to process the graph. Make sure to use
     53   // static shape inference to prevent the schedulrer from calling the Run
     54   // method on the cluster, and create an infinite loop.
     55   GrapplerItem item;
     56   item.graph = graph;
     57   item.feed = feed;
     58   item.fetch = fetch;
     59   VirtualScheduler scheduler(&item, true, this, node_manager_.get());
     60   TF_RETURN_IF_ERROR(scheduler.Init());
     61 
     62   if (metadata) {
     63     metadata->clear_step_stats();
     64     metadata->clear_cost_graph();
     65     metadata->clear_partition_graphs();
     66   }
     67 
     68   Costs node_costs;
     69   do {
     70     OpContext op_context = scheduler.GetCurrNode();
     71     node_costs = node_estimator_->PredictCosts(op_context);
     72     if (metadata) {
     73       CostGraphDef::Node* cost_node =
     74           metadata->mutable_cost_graph()->add_node();
     75       const string& op_name = op_context.name;
     76       cost_node->set_name(op_name);
     77       cost_node->set_device(op_context.device_name);
     78       cost_node->set_compute_cost(
     79           node_costs.execution_time.asMicroSeconds().count());
     80       cost_node->set_compute_time(
     81           node_costs.compute_time.asMicroSeconds().count());
     82       cost_node->set_memory_time(
     83           node_costs.memory_time.asMicroSeconds().count());
     84       for (const auto& output : op_context.op_info.outputs()) {
     85         auto output_info = cost_node->add_output_info();
     86         output_info->set_dtype(output.dtype());
     87         *output_info->mutable_shape() = output.shape();
     88 
     89         int64 size = DataTypeSize(output.dtype());
     90         for (const auto& dim : output.shape().dim()) {
     91           size *= std::max<int64>(1, dim.size());
     92         }
     93         output_info->set_size(size);
     94       }
     95     }
     96   } while (scheduler.MarkCurrNodeExecuted(node_costs));
     97 
     98   if (metadata) {
     99     scheduler.Summary(metadata);
    100   }
    101 
    102   const std::unordered_map<string, DeviceProperties>& device = GetDevices();
    103   std::unordered_map<string, int64> peak_mem_usage =
    104       scheduler.GetPeakMemoryUsage();
    105   for (const auto& mem_usage : peak_mem_usage) {
    106     const string& device_name = mem_usage.first;
    107     auto it = device.find(device_name);
    108     if (it == device.end()) {
    109       // It's probably the fake send/recv device. Eventually we'll need to
    110       // remove this fake device to ensure proper memory accounting for
    111       // multi-device settings.
    112       continue;
    113     }
    114     const DeviceProperties& dev = it->second;
    115     if (dev.memory_size() <= 0) {
    116       // Available device memory unknown
    117       continue;
    118     }
    119     int64 peak_mem = mem_usage.second;
    120     if (peak_mem >= dev.memory_size()) {
    121       return errors::ResourceExhausted(
    122           "Graph requires ", peak_mem, " bytes of memory on device ",
    123           device_name, " to run ", " but device only has ", dev.memory_size(),
    124           " available.");
    125     }
    126   }
    127 
    128   return Status::OK();
    129 }
    130 
    131 }  // namespace grappler
    132 }  // namespace tensorflow
    133