1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include <memory> 16 #include <string> 17 #include <unordered_map> 18 #include <vector> 19 20 #include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" 21 #include "tensorflow/contrib/lite/toco/model.h" 22 #include "tensorflow/core/platform/logging.h" 23 24 namespace toco { 25 26 namespace { 27 void SetDataTypeForAllOutputs(Model* model, Operator* op, 28 ArrayDataType data_type) { 29 for (const auto& output : op->outputs) { 30 model->GetArray(output).data_type = data_type; 31 } 32 } 33 } // namespace 34 35 bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { 36 auto it = model->operators.begin() + op_index; 37 auto* op = it->get(); 38 39 // If the data type of some input is unknown, we need to yield. 40 for (const auto& input : op->inputs) { 41 if (!model->IsOptionalArray(input) && 42 model->GetArray(input).data_type == ArrayDataType::kNone) { 43 return false; 44 } 45 } 46 // Record data types of output before processing, so we can see at the 47 // end if we changed anything, and return the correct boolean value. 48 std::unordered_map<string, ArrayDataType> old_output_data_types; 49 for (const auto& output : op->outputs) { 50 old_output_data_types[output] = model->GetArray(output).data_type; 51 } 52 // Do the actual output data types propagation. 53 if (op->type == OperatorType::kDequantize || 54 op->type == OperatorType::kResizeBilinear) { 55 // These operators unconditionally produce float outputs 56 SetDataTypeForAllOutputs(model, op, ArrayDataType::kFloat); 57 } else if (op->type == OperatorType::kTensorFlowLess || 58 op->type == OperatorType::kTensorFlowLessEqual || 59 op->type == OperatorType::kTensorFlowGreater || 60 op->type == OperatorType::kTensorFlowGreaterEqual) { 61 // These operators unconditionally produce bool outputs 62 SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool); 63 } else if (op->type == OperatorType::kRank || 64 op->type == OperatorType::kTensorFlowShape) { 65 // These operators only produce int32 outputs. 66 SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32); 67 } else if (op->type == OperatorType::kTensorFlowSplit || 68 op->type == OperatorType::kTensorFlowConcat || 69 op->type == OperatorType::kFill) { 70 // These operators produce an output with the same type as their 2nd input 71 CHECK_GE(op->inputs.size(), 2); 72 const ArrayDataType data_type = model->GetArray(op->inputs[1]).data_type; 73 SetDataTypeForAllOutputs(model, op, data_type); 74 } else if (op->type == OperatorType::kCast) { 75 // Data type of the Cast op is specified. 76 CHECK_EQ(op->outputs.size(), 1); 77 auto* cast_op = static_cast<CastOperator*>(op); 78 model->GetArray(op->outputs[0]).data_type = cast_op->dst_data_type; 79 } else if (op->type == OperatorType::kArgMax) { 80 // Data type of the ArgMax op is specified. 81 CHECK_EQ(op->outputs.size(), 1); 82 auto* argmax_op = static_cast<ArgMaxOperator*>(op); 83 model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type; 84 } else if (op->type == OperatorType::kRange) { 85 auto* range_op = static_cast<RangeOperator*>(op); 86 // Output type of the Range op can be set via an attribute 87 ArrayDataType data_type; 88 if (range_op->dtype != ArrayDataType::kNone) { 89 // Use the type if specified 90 data_type = range_op->dtype; 91 } else { 92 // Otherwise use the first input 93 CHECK_GE(op->inputs.size(), 1); 94 data_type = model->GetArray(op->inputs[0]).data_type; 95 } 96 CHECK_EQ(op->outputs.size(), 1); 97 SetDataTypeForAllOutputs(model, op, data_type); 98 } else if (op->type == OperatorType::kTensorFlowUnsupported) { 99 auto* unsupported_op = static_cast<TensorFlowUnsupportedOperator*>(op); 100 if (unsupported_op->output_data_types.size() != op->outputs.size()) { 101 return false; 102 } 103 for (int i = 0; i < unsupported_op->output_data_types.size(); ++i) { 104 auto output = op->outputs[i]; 105 auto data_type = unsupported_op->output_data_types[i]; 106 model->GetArray(output).data_type = data_type; 107 } 108 } else if (op->type == OperatorType::kExpandDims) { 109 // Yield on ExpandDim until it is converted to Reshape 110 return false; 111 } else { 112 // These operators produce outputs with the same type as their 1st input 113 CHECK_GT(op->inputs.size(), 0); 114 const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type; 115 SetDataTypeForAllOutputs(model, op, data_type); 116 } 117 // Return true if any output data type changed, false if none changed. 118 for (const auto& output : op->outputs) { 119 if (old_output_data_types[output] != model->GetArray(output).data_type) { 120 return true; 121 } 122 } 123 return false; 124 } 125 126 } // namespace toco 127