Home | History | Annotate | Download | only in optimizers
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/grappler/optimizers/static_schedule.h"
     17 #include <deque>
     18 #include "tensorflow/core/framework/attr_value.pb.h"
     19 #include "tensorflow/core/grappler/costs/graph_properties.h"
     20 #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
     21 #include "tensorflow/core/grappler/costs/virtual_placer.h"
     22 #include "tensorflow/core/grappler/op_types.h"
     23 #include "tensorflow/core/grappler/utils.h"
     24 #include "tensorflow/core/lib/core/errors.h"
     25 #include "tensorflow/core/lib/strings/strcat.h"
     26 
     27 namespace tensorflow {
     28 namespace grappler {
     29 
     30 static Costs::NanoSeconds PredictExecutionTime(
     31     const GraphProperties& properties, const OpLevelCostEstimator& estimator,
     32     const VirtualPlacer& placer, const NodeDef& node) {
     33   OpContext op_context;
     34   op_context.op_info.set_op(node.op());
     35   *op_context.op_info.mutable_attr() = node.attr();
     36 
     37   std::vector<OpInfo::TensorProperties> inputs =
     38       properties.GetInputProperties(node.name());
     39   for (auto& input : inputs) {
     40     op_context.op_info.add_inputs()->Swap(&input);
     41   }
     42 
     43   DeviceProperties device = placer.get_device(node);
     44   op_context.op_info.mutable_device()->Swap(&device);
     45 
     46   Costs::NanoSeconds estimate =
     47       estimator.PredictCosts(op_context).execution_time;
     48 
     49   // Make sure our estimates are at least one nanosecond per node.
     50   return std::max(estimate, Costs::NanoSeconds(1));
     51 }
     52 
// Computes, for every node in `item`'s graph, the earliest time at which the
// node can complete, assuming unbounded parallelism: a node starts once all
// of its inputs are available (or any single input, for Merge nodes) and
// runs for the duration predicted by the op-level cost estimator.
// On success, `completion_times` maps each node to its estimated completion
// time; returns InvalidArgument if the graph references an unknown input.
Status EstimateEarliestExecutionTimes(
    const GrapplerItem& item, const Cluster* cluster,
    std::unordered_map<const NodeDef*, Costs::NanoSeconds>* completion_times) {
  std::unordered_map<string, const NodeDef*> name_map;
  // Number of fanins a node is still waiting on before it can be scheduled.
  std::unordered_map<const NodeDef*, int> pending_inputs;
  std::deque<const NodeDef*> ready_nodes;
  for (const NodeDef& node : item.graph.node()) {
    name_map[node.name()] = &node;
    if (node.input_size() == 0) {
      // Source nodes are immediately ready and start at time 0.
      ready_nodes.push_back(&node);
      (*completion_times)[&node] = 0;
    } else if (IsMerge(node)) {
      // Merge nodes are processed as soon as one of the input becomes
      // available.
      pending_inputs[&node] = 1;
    } else {
      pending_inputs[&node] = node.input_size();
    }
  }

  // Build the fanout (consumer) lists by inverting the input edges.
  std::unordered_map<const NodeDef*, std::vector<const NodeDef*>> fanouts;
  for (const NodeDef& node : item.graph.node()) {
    for (const string& input : node.input()) {
      string node_name = NodeName(input);
      auto it = name_map.find(node_name);
      if (it == name_map.end()) {
        return errors::InvalidArgument(
            strings::StrCat("Unknown input node ", input));
      }
      const NodeDef* fanin = it->second;
      fanouts[fanin].push_back(&node);
    }
  }
  name_map.clear();

  GraphProperties properties(item);
  TF_RETURN_IF_ERROR(properties.InferStatically(true));
  OpLevelCostEstimator estimator;
  VirtualPlacer placer(cluster);

  // Propagate completion times forward in (roughly topological) BFS order.
  while (!ready_nodes.empty()) {
    const NodeDef* node = ready_nodes.front();
    ready_nodes.pop_front();

    Costs::NanoSeconds execution_time =
        PredictExecutionTime(properties, estimator, placer, *node);
    // The node's start time was accumulated into completion_times while its
    // fanins were processed; add the node's own cost to get its completion.
    Costs::NanoSeconds completion_time =
        execution_time + (*completion_times)[node];
    (*completion_times)[node] = completion_time;

    for (const NodeDef* fanout : fanouts[node]) {
      int pending = pending_inputs[fanout];
      if (pending == 0) {
        // Already processed. Avoid going through loops more than once.
        continue;
      } else if (pending == 1) {
        // Last missing input: the fanout becomes ready.
        ready_nodes.push_back(fanout);
      }
      pending_inputs[fanout]--;

      // A fanout cannot start before its slowest fanin (so far) completes.
      Costs::NanoSeconds ready_time =
          std::max(completion_time, (*completion_times)[fanout]);
      (*completion_times)[fanout] = ready_time;
    }
  }

  return Status::OK();
}
    121 
// Computes, for every node, the latest time by which the node must complete
// so that the deadlines given in `execution_times` can still be met.
// Propagates deadlines backwards from the graph's sink nodes: each fanin
// must finish early enough for its consumer to run (consumer's required
// time minus the consumer's predicted execution time). Nodes without a
// deadline default to NanoSeconds::max() (i.e. unconstrained).
// Returns InvalidArgument if the graph references an unknown input node.
Status EstimateRequiredTimes(
    const GrapplerItem& item, const Cluster* cluster,
    const std::unordered_map<const NodeDef*, Costs::NanoSeconds>&
        execution_times,
    std::unordered_map<const NodeDef*, Costs::NanoSeconds>* required_times) {
  std::unordered_map<string, const NodeDef*> name_map;
  for (const NodeDef& node : item.graph.node()) {
    name_map[node.name()] = &node;
    // Start unconstrained; deadlines are tightened via std::min below.
    (*required_times)[&node] = Costs::NanoSeconds::max();
  }

  // Count each node's consumers; the backward pass visits a node only once
  // all of its fanouts have been processed.
  std::unordered_map<const NodeDef*, int> pending_fanouts;
  for (const NodeDef& node : item.graph.node()) {
    for (const string& input : node.input()) {
      string node_name = NodeName(input);
      auto it = name_map.find(node_name);
      if (it == name_map.end()) {
        return errors::InvalidArgument(
            strings::StrCat("Unknown input node ", input));
      }
      const NodeDef* fanin = it->second;
      pending_fanouts[fanin] += 1;
    }
  }
  std::deque<const NodeDef*> ready_nodes;
  for (const NodeDef& node : item.graph.node()) {
    // Nodes with no fanouts are sinks: seed them with their deadline from
    // execution_times, if one was provided.
    if (pending_fanouts[&node] == 0) {
      auto it = execution_times.find(&node);
      if (it != execution_times.end()) {
        (*required_times)[&node] = it->second;
      }
      ready_nodes.push_back(&node);
    }
  }
  GraphProperties properties(item);
  TF_RETURN_IF_ERROR(properties.InferStatically(true));
  OpLevelCostEstimator estimator;
  VirtualPlacer placer(cluster);

  // Propagate required times backwards in reverse-topological BFS order.
  while (!ready_nodes.empty()) {
    const NodeDef* node = ready_nodes.front();
    ready_nodes.pop_front();

    Costs::NanoSeconds execution_time =
        PredictExecutionTime(properties, estimator, placer, *node);
    // Each fanin must complete by the time this node has to start.
    Costs::NanoSeconds required_time = (*required_times)[node] - execution_time;

    for (const string& fanin_name : node->input()) {
      const NodeDef* fanin = name_map[NodeName(fanin_name)];
      // Keep the tightest (smallest) deadline across all consumers.
      (*required_times)[fanin] =
          std::min((*required_times)[fanin], required_time);

      int pending = pending_fanouts[fanin];
      if (pending == 0) {
        // Already processed. Avoid going through loops more than once.
        continue;
      } else if (pending == 1) {
        // Last remaining fanout: the fanin can now be processed.
        ready_nodes.push_back(fanin);
      }
      pending_fanouts[fanin]--;
    }
  }

  return Status::OK();
}
    187 
    188 }  // end namespace grappler
    189 }  // end namespace tensorflow
    190