Home | History | Annotate | Download | only in graph_transformations
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include <memory>
     16 #include <string>
     17 #include <unordered_map>
     18 #include <vector>
     19 
     20 #include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
     21 #include "tensorflow/contrib/lite/toco/model.h"
     22 #include "tensorflow/core/platform/logging.h"
     23 
     24 namespace toco {
     25 
     26 namespace {
     27 void SetDataTypeForAllOutputs(Model* model, Operator* op,
     28                               ArrayDataType data_type) {
     29   for (const auto& output : op->outputs) {
     30     model->GetArray(output).data_type = data_type;
     31   }
     32 }
     33 }  // namespace
     34 
     35 bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
     36   auto it = model->operators.begin() + op_index;
     37   auto* op = it->get();
     38 
     39   // If the data type of some input is unknown, we need to yield.
     40   for (const auto& input : op->inputs) {
     41     if (!model->IsOptionalArray(input) &&
     42         model->GetArray(input).data_type == ArrayDataType::kNone) {
     43       return false;
     44     }
     45   }
     46   // Record data types of output before processing, so we can see at the
     47   // end if we changed anything, and return the correct boolean value.
     48   std::unordered_map<string, ArrayDataType> old_output_data_types;
     49   for (const auto& output : op->outputs) {
     50     old_output_data_types[output] = model->GetArray(output).data_type;
     51   }
     52   // Do the actual output data types propagation.
     53   if (op->type == OperatorType::kDequantize ||
     54       op->type == OperatorType::kResizeBilinear) {
     55     // These operators unconditionally produce float outputs
     56     SetDataTypeForAllOutputs(model, op, ArrayDataType::kFloat);
     57   } else if (op->type == OperatorType::kTensorFlowLess ||
     58              op->type == OperatorType::kTensorFlowLessEqual ||
     59              op->type == OperatorType::kTensorFlowGreater ||
     60              op->type == OperatorType::kTensorFlowGreaterEqual) {
     61     // These operators unconditionally produce bool outputs
     62     SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool);
     63   } else if (op->type == OperatorType::kRank ||
     64              op->type == OperatorType::kTensorFlowShape) {
     65     // These operators only produce int32 outputs.
     66     SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32);
     67   } else if (op->type == OperatorType::kTensorFlowSplit ||
     68              op->type == OperatorType::kTensorFlowConcat ||
     69              op->type == OperatorType::kFill) {
     70     // These operators produce an output with the same type as their 2nd input
     71     CHECK_GE(op->inputs.size(), 2);
     72     const ArrayDataType data_type = model->GetArray(op->inputs[1]).data_type;
     73     SetDataTypeForAllOutputs(model, op, data_type);
     74   } else if (op->type == OperatorType::kCast) {
     75     // Data type of the Cast op is specified.
     76     CHECK_EQ(op->outputs.size(), 1);
     77     auto* cast_op = static_cast<CastOperator*>(op);
     78     model->GetArray(op->outputs[0]).data_type = cast_op->dst_data_type;
     79   } else if (op->type == OperatorType::kArgMax) {
     80     // Data type of the ArgMax op is specified.
     81     CHECK_EQ(op->outputs.size(), 1);
     82     auto* argmax_op = static_cast<ArgMaxOperator*>(op);
     83     model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type;
     84   } else if (op->type == OperatorType::kRange) {
     85     auto* range_op = static_cast<RangeOperator*>(op);
     86     // Output type of the Range op can be set via an attribute
     87     ArrayDataType data_type;
     88     if (range_op->dtype != ArrayDataType::kNone) {
     89       // Use the type if specified
     90       data_type = range_op->dtype;
     91     } else {
     92       // Otherwise use the first input
     93       CHECK_GE(op->inputs.size(), 1);
     94       data_type = model->GetArray(op->inputs[0]).data_type;
     95     }
     96     CHECK_EQ(op->outputs.size(), 1);
     97     SetDataTypeForAllOutputs(model, op, data_type);
     98   } else if (op->type == OperatorType::kTensorFlowUnsupported) {
     99     auto* unsupported_op = static_cast<TensorFlowUnsupportedOperator*>(op);
    100     if (unsupported_op->output_data_types.size() != op->outputs.size()) {
    101       return false;
    102     }
    103     for (int i = 0; i < unsupported_op->output_data_types.size(); ++i) {
    104       auto output = op->outputs[i];
    105       auto data_type = unsupported_op->output_data_types[i];
    106       model->GetArray(output).data_type = data_type;
    107     }
    108   } else if (op->type == OperatorType::kExpandDims) {
    109     // Yield on ExpandDim until it is converted to Reshape
    110     return false;
    111   } else {
    112     // These operators produce outputs with the same type as their 1st input
    113     CHECK_GT(op->inputs.size(), 0);
    114     const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type;
    115     SetDataTypeForAllOutputs(model, op, data_type);
    116   }
    117   // Return true if any output data type changed, false if none changed.
    118   for (const auto& output : op->outputs) {
    119     if (old_output_data_types[output] != model->GetArray(output).data_type) {
    120       return true;
    121     }
    122   }
    123   return false;
    124 }
    125 
    126 }  // namespace toco
    127