/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"

// This transformation rule tries to identify the PRelu structure generated by
// Keras and convert it into a single op.
//
// The formula of PReLU is:
// f(x) = alpha * x for x < 0, f(x) = x for x >= 0.
//
// `x` is the input, and `alpha` is a trainable tensor which can be broadcast
// to the shape of `x`.
//
// There's no native PRelu op in TensorFlow, so Keras generates the following
// structure, which performs the equivalent calculation:
// f(x) = Relu(x) + (-alpha * Relu(-x))
//
// In practice, alpha is always a constant in the inference graph, and Toco has
// other graph transformations which fold activation functions into other ops.
// Therefore, we're looking for the structure:
//
// f(x) = Relu(x) + (negative_alpha * Neg(x, activation=Relu))

namespace toco {

::tensorflow::Status IdentifyPRelu::Run(Model* model, std::size_t op_index,
                                        bool* modified) {
  *modified = false;
  const auto add_op_it = model->operators.begin() + op_index;
  const auto* add_op = add_op_it->get();
  if (add_op == nullptr || add_op->type != OperatorType::kAdd ||
      add_op->inputs.size() != 2 ||
      add_op->fused_activation_function != FusedActivationFunctionType::kNone) {
    return ::tensorflow::Status::OK();
  }

  const auto* relu_input_op = GetOpWithOutput(*model, add_op->inputs[0]);
  if (relu_input_op == nullptr || relu_input_op->type != OperatorType::kRelu ||
      relu_input_op->inputs.size() != 1 ||
      relu_input_op->fused_activation_function !=
          FusedActivationFunctionType::kNone) {
    return ::tensorflow::Status::OK();
  }

  // TODO(ycling): Both Add and Mul are commutative. Support the case where
  // the positions of the operands are exchanged.
  const auto* mul_op = GetOpWithOutput(*model, add_op->inputs[1]);
  if (mul_op == nullptr || mul_op->type != OperatorType::kMul ||
      mul_op->inputs.size() != 2 ||
      mul_op->fused_activation_function != FusedActivationFunctionType::kNone) {
    return ::tensorflow::Status::OK();
  }

  const auto neg_alpha_tensor_name = mul_op->inputs[0];

  const auto* relu_neg_input_op = GetOpWithOutput(*model, mul_op->inputs[1]);

  if (relu_neg_input_op == nullptr ||
      relu_neg_input_op->inputs.size() != 1) {
    return ::tensorflow::Status::OK();
  }

  const Operator* final_input_op;
  if (relu_neg_input_op->type == OperatorType::kNeg &&
      relu_neg_input_op->fused_activation_function ==
          FusedActivationFunctionType::kRelu) {
    // This detects a Neg op with a fused Relu activation function.
    final_input_op = relu_neg_input_op;
  } else {
    // This detects a Neg op followed by a separate Relu op.
    const auto* neg_input_op =
        GetOpWithOutput(*model, relu_neg_input_op->inputs[0]);
    if (neg_input_op == nullptr || neg_input_op->type != OperatorType::kNeg ||
        neg_input_op->inputs.size() != 1 ||
        relu_neg_input_op->type != OperatorType::kRelu ||
        relu_neg_input_op->fused_activation_function !=
            FusedActivationFunctionType::kNone) {
      return ::tensorflow::Status::OK();
    }
    final_input_op = neg_input_op;
  }

  if (relu_input_op->inputs[0] != final_input_op->inputs[0]) {
    return ::tensorflow::Status::OK();
  }

  const auto input_tensor_name = relu_input_op->inputs[0];
  const auto output_tensor_name = add_op->outputs[0];

  // Construct a tensor for the positive alpha (double negation).
  const auto alpha_tensor_name =
      AvailableArrayName(*model, neg_alpha_tensor_name + "_neg");
  model->GetOrCreateArray(alpha_tensor_name);

  auto* neg_neg_alpha_op = new NegOperator;
  neg_neg_alpha_op->inputs = {neg_alpha_tensor_name};
  neg_neg_alpha_op->outputs = {alpha_tensor_name};
  model->operators.emplace(add_op_it, neg_neg_alpha_op);

  auto* prelu_op = new PReluOperator;
  prelu_op->inputs = {input_tensor_name, alpha_tensor_name};
  prelu_op->outputs = {output_tensor_name};
  model->operators.emplace(add_op_it, prelu_op);
  AddMessageF("Creating %s replacing equivalent subgraph", LogName(*prelu_op));

  DeleteArrayIfUsedOnce(neg_alpha_tensor_name, model);
  DeleteArrayIfUsedOnce(add_op->inputs[0], model);
  DeleteArrayIfUsedOnce(add_op->inputs[1], model);
  DeleteArrayIfUsedOnce(mul_op->inputs[1], model);
  // Remove the existing Add op that outputs the final result. If the other
  // intermediate tensors aren't used by other ops, they will be removed by
  // other graph transformation rules.
  model->operators.erase(FindOp(*model, add_op));
  *modified = true;
  return ::tensorflow::Status::OK();
}

}  // namespace toco