1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/lite/c/c_api_internal.h" 16 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" 17 #include "tensorflow/lite/kernels/internal/tensor.h" 18 #include "tensorflow/lite/kernels/kernel_util.h" 19 #include "tensorflow/lite/kernels/op_macros.h" 20 21 namespace tflite { 22 namespace ops { 23 namespace builtin { 24 namespace pow { 25 namespace { 26 27 // Input/output tensor index. 28 constexpr int kInputTensor1 = 0; 29 constexpr int kInputTensor2 = 1; 30 constexpr int kOutputTensor = 0; 31 32 // Op data for pow op. 33 struct OpData { 34 bool requires_broadcast; 35 }; 36 37 void* Init(TfLiteContext* context, const char* buffer, size_t length) { 38 auto* data = new OpData; 39 data->requires_broadcast = false; 40 return data; 41 } 42 43 void Free(TfLiteContext* context, void* buffer) { 44 delete reinterpret_cast<OpData*>(buffer); 45 } 46 47 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { 48 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); 49 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); 50 51 OpData* data = reinterpret_cast<OpData*>(node->user_data); 52 53 const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); 54 const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); 55 TfLiteTensor* output = GetOutput(context, node, kOutputTensor); 56 57 TF_LITE_ENSURE_EQ(context, input1->type, input2->type); 58 59 const TfLiteType type = input1->type; 60 if (type != kTfLiteInt32 && type != kTfLiteFloat32) { 61 context->ReportError(context, "Unsupported data type %d.", type); 62 return kTfLiteError; 63 } 64 output->type = type; 65 66 data->requires_broadcast = !HaveSameShapes(input1, input2); 67 68 TfLiteIntArray* output_size = nullptr; 69 if (data->requires_broadcast) { 70 TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast( 71 context, input1, input2, &output_size)); 72 } else { 73 output_size = TfLiteIntArrayCopy(input1->dims); 74 } 75 76 return context->ResizeTensor(context, output, output_size); 77 } 78 79 template <typename T> 80 void PowImpl(const TfLiteTensor* input1, const TfLiteTensor* input2, 81 TfLiteTensor* output, bool requires_broadcast) { 82 if (requires_broadcast) { 83 reference_ops::BroadcastPow4DSlow( 84 GetTensorShape(input1), GetTensorData<T>(input1), 85 GetTensorShape(input2), GetTensorData<T>(input2), 86 GetTensorShape(output), GetTensorData<T>(output)); 87 } else { 88 reference_ops::Pow(GetTensorShape(input1), GetTensorData<T>(input1), 89 GetTensorShape(input2), GetTensorData<T>(input2), 90 GetTensorShape(output), GetTensorData<T>(output)); 91 } 92 } 93 94 TfLiteStatus CheckValue(TfLiteContext* context, const TfLiteTensor* input) { 95 const int64_t num_elements = NumElements(input); 96 const int32_t* data = GetTensorData<int32_t>(input); 97 for (int i = 0; i < num_elements; ++i) { 98 if (data[i] < 0) { 99 context->ReportError(context, 100 "POW does not support negative value for int32."); 101 return kTfLiteError; 102 } 103 } 104 return kTfLiteOk; 105 } 106 107 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { 108 OpData* data = reinterpret_cast<OpData*>(node->user_data); 109 110 const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); 111 const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); 112 TfLiteTensor* output = GetOutput(context, node, kOutputTensor); 113 114 switch (output->type) { 115 case kTfLiteInt32: { 116 // TensorFlow does not support negative for int32. 117 TF_LITE_ENSURE_OK(context, CheckValue(context, input2)); 118 PowImpl<int32_t>(input1, input2, output, data->requires_broadcast); 119 break; 120 } 121 case kTfLiteFloat32: { 122 PowImpl<float>(input1, input2, output, data->requires_broadcast); 123 break; 124 } 125 default: { 126 context->ReportError(context, "Unsupported data type: %d", output->type); 127 return kTfLiteError; 128 } 129 } 130 return kTfLiteOk; 131 } 132 133 } // namespace 134 } // namespace pow 135 136 TfLiteRegistration* Register_POW() { 137 static TfLiteRegistration r = {pow::Init, pow::Free, pow::Prepare, pow::Eval}; 138 return &r; 139 } 140 141 } // namespace builtin 142 } // namespace ops 143 } // namespace tflite 144