/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
     15 #include <memory>
     16 #include <string>
     17 #include <unordered_map>
     18 #include <vector>
     19 
     20 #include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
     21 #include "tensorflow/contrib/lite/toco/model.h"
     22 #include "tensorflow/contrib/lite/toco/tooling_util.h"
     23 #include "tensorflow/core/platform/logging.h"
     24 
     25 namespace toco {
     26 
     27 bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) {
     28   const auto switch_it = model->operators.begin() + op_index;
     29   const auto* switch_op = switch_it->get();
     30   if (switch_op->type != OperatorType::kTensorFlowSwitch) {
     31     return false;
     32   }
     33 
     34   CHECK_EQ(switch_op->inputs.size(), 2);
     35   CHECK_EQ(switch_op->outputs.size(), 2);
     36   const string& predicate_name = switch_op->inputs[1];
     37   // If the predicate array hasn't been resolved to a constant yet,
     38   // we need to yield.
     39   if (!IsConstantParameterArray(*model, predicate_name)) {
     40     AddMessageF(
     41         "Waiting for the boolean predicate of %s to be resolved to a constant",
     42         LogName(*switch_op));
     43     return false;
     44   }
     45 
     46   // The predicate should be boolean, and should consist of a single value.
     47   const auto& predicate_array = model->GetArray(predicate_name);
     48   CHECK(predicate_array.data_type == ArrayDataType::kBool);
     49   for (const auto& dim : predicate_array.shape().dims()) {
     50     CHECK_EQ(dim, 1);
     51   }
     52 
     53   // Obtain the predicate boolean value.
     54   const auto& predicate_data =
     55       predicate_array.GetBuffer<ArrayDataType::kBool>().data;
     56   CHECK_EQ(predicate_data.size(), 1);
     57   const bool predicate_value = predicate_data[0];
     58 
     59   // From the TensorFlow docs on .switch() in
     60   // third_party/tensorflow/python/ops/control_flow_ops.py
     61   //
     62   //    If `pred` is false, the `data` input is forwarded to the first output.
     63   //    Otherwise, the data goes to the second output.
     64   //
     65   // Note that this comment used to say the opposite and was recently fixed:
     66   // https://github.com/tensorflow/tensorflow/commit/bc456e361d49d1d89a74b80060c70efb51fd7d87#diff-76ab9dafbe12c20ddc3769c6b108986c
     67   const int selected_output_index = predicate_value ? 1 : 0;
     68   const int nonselected_output_index = predicate_value ? 0 : 1;
     69 
     70   // Update the edges of the graph ahead of removing the node:
     71   // edges that were pointing to the selected output, should instead
     72   // point to the input of the Switch node.
     73   for (const auto& other_op : model->operators) {
     74     for (auto& input : other_op->inputs) {
     75       if (input == switch_op->outputs[selected_output_index]) {
     76         input = switch_op->inputs[0];
     77       }
     78     }
     79   }
     80 
     81   // There remains to handle the edges that were pointing to the nonselected
     82   // output. We will just discard those edges. Concretely, at the moment,
     83   // our only examples of graphs with Switch nodes have them feeding into Merge
     84   // nodes, so what we're saying here is that we'll make the convention,
     85   // in our toco internal representation, that Merge nodes with only 1 input
     86   // are Merge nodes that have been resolved already and should be have as
     87   // Identity nodes, simply forwarding their input.
     88   //
     89   for (const auto& other_op : model->operators) {
     90     auto input_it = other_op->inputs.begin();
     91     while (input_it != other_op->inputs.end()) {
     92       if (*input_it == switch_op->outputs[nonselected_output_index]) {
     93         // Let us guard our assumption that only Merge nodes consume the outputs
     94         // of Switch nodes:
     95         CHECK(other_op->type == OperatorType::kTensorFlowMerge);
     96         input_it = other_op->inputs.erase(input_it);
     97       } else {
     98         ++input_it;
     99       }
    100     }
    101   }
    102 
    103   // Remove the output arrays if they are now unused.
    104   for (int i = 0; i < 2; i++) {
    105     if (!GetOpWithInput(*model, switch_op->outputs[i])) {
    106       model->EraseArray(switch_op->outputs[i]);
    107     }
    108   }
    109   // Remove input arrays if they are only used by the switch itself and aren't
    110   // the output of another op (will get handled by RemoveUnusedOp in that case).
    111   for (const auto& input : switch_op->inputs) {
    112     if (CountOpsWithInput(*model, input) == 1 &&
    113         !GetOpWithOutput(*model, input)) {
    114       model->EraseArray(input);
    115     }
    116   }
    117   // Remove the switch node itself.
    118   AddMessageF("Removing already-resolved %s", LogName(*switch_op));
    119   model->operators.erase(switch_it);
    120   return true;
    121 }
    122 
    123 }  // namespace toco
    124