1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include <cassert> 16 #include <cmath> 17 #include <cstdio> 18 #include <cstdlib> 19 #include <iostream> 20 #include <limits> 21 22 #include "tensorflow/lite/c/builtin_op_data.h" 23 #include "tensorflow/lite/c/c_api_internal.h" 24 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" 25 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" 26 #include "tensorflow/lite/kernels/internal/tensor.h" 27 #include "tensorflow/lite/kernels/kernel_util.h" 28 #include "tensorflow/lite/kernels/op_macros.h" 29 30 namespace tflite { 31 namespace ops { 32 namespace builtin { 33 namespace concatenation { 34 35 // This file has two implementation of Concatenation. 36 enum KernelType { 37 kReference, 38 kGenericOptimized, 39 }; 40 41 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { 42 auto* params = 43 reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data); 44 int axis = params->axis; 45 int num_inputs = node->inputs->size; 46 47 // The number of dimensions of the input tensors must match, and all 48 // dimensions except 'axis' must be equal. 49 TfLiteTensor* t0 = &context->tensors[node->inputs->data[0]]; 50 TfLiteType input_type = t0->type; 51 if (axis < 0) axis += t0->dims->size; 52 TF_LITE_ENSURE(context, axis >= 0); 53 TF_LITE_ENSURE(context, axis < t0->dims->size); 54 55 // TODO(ahentz): These are limitations of our implementation that could be 56 // removed with a bit of effort. 57 TF_LITE_ENSURE(context, t0->dims->size <= 4); 58 TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone); 59 TF_LITE_ENSURE(context, 60 input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 || 61 input_type == kTfLiteInt8 || input_type == kTfLiteInt16 || 62 input_type == kTfLiteInt32 || input_type == kTfLiteInt64); 63 64 // Output dimensions will match input dimensions, except 'axis', which 65 // will be the sum of inputs 66 int sum_axis = t0->dims->data[axis]; 67 for (int i = 1; i < num_inputs; ++i) { 68 TfLiteTensor* t = &context->tensors[node->inputs->data[i]]; 69 TF_LITE_ENSURE_EQ(context, t->dims->size, t0->dims->size); 70 TF_LITE_ENSURE_EQ(context, t->type, input_type); 71 for (int d = 0; d < t0->dims->size; ++d) { 72 if (d == axis) { 73 sum_axis += t->dims->data[axis]; 74 } else { 75 TF_LITE_ENSURE_EQ(context, t->dims->data[d], t0->dims->data[d]); 76 } 77 } 78 } 79 80 TfLiteIntArray* output_size = TfLiteIntArrayCreate(t0->dims->size); 81 for (int d = 0; d < t0->dims->size; ++d) { 82 output_size->data[d] = (d == axis) ? sum_axis : t0->dims->data[d]; 83 } 84 85 TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; 86 TF_LITE_ENSURE_EQ(context, output->type, input_type); 87 88 if (input_type == kTfLiteInt8) { 89 // Make sure there is no re-scaling needed for Int8 quantized kernel. This 90 // is a restriction we introduced to Int8 kernels. 91 VectorOfTensors<int8_t> all_inputs(*context, *node->inputs); 92 for (int i = 0; i < node->inputs->size; ++i) { 93 TfLiteTensor* t = &context->tensors[node->inputs->data[i]]; 94 TF_LITE_ENSURE_EQ(context, t->params.scale, output->params.scale); 95 TF_LITE_ENSURE_EQ(context, t->params.zero_point, 96 output->params.zero_point); 97 } 98 } 99 100 return context->ResizeTensor(context, output, output_size); 101 } 102 103 template <KernelType kernel_type> 104 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { 105 auto* params = 106 reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data); 107 int axis = params->axis; 108 TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; 109 if (axis < 0) axis += output->dims->size; 110 111 // TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should 112 // allocate and populate these during Prepare(). 113 // TODO(ycling): Activation function parameter is ignored. For now we dont have 114 // a model with a Concatenation with fused activation function. 115 #define TF_LITE_CONCATENATION(type, scalar) \ 116 { \ 117 VectorOfTensors<scalar> all_inputs(*context, *node->inputs); \ 118 tflite::ConcatenationParams op_params; \ 119 op_params.axis = axis; \ 120 op_params.inputs_count = node->inputs->size; \ 121 type::Concatenation(op_params, all_inputs.shapes(), all_inputs.data(), \ 122 GetTensorShape(output), \ 123 GetTensorData<scalar>(output)); \ 124 } 125 126 #define TF_LITE_CONCATENATION_QUANTIZED(type) \ 127 { \ 128 VectorOfQuantizedTensors all_inputs(*context, *node->inputs); \ 129 tflite::ConcatenationParams op_params; \ 130 op_params.axis = axis; \ 131 op_params.input_zeropoint = all_inputs.zero_point(); \ 132 op_params.input_scale = all_inputs.scale(); \ 133 op_params.inputs_count = node->inputs->size; \ 134 op_params.output_zeropoint = output->params.zero_point; \ 135 op_params.output_scale = output->params.scale; \ 136 type::ConcatenationWithScaling(op_params, all_inputs.shapes(), \ 137 all_inputs.data(), GetTensorShape(output), \ 138 GetTensorData<uint8>(output)); \ 139 } 140 141 switch (output->type) { // Already know in/outtypes are same. 142 case kTfLiteFloat32: 143 if (kernel_type == kReference) { 144 TF_LITE_CONCATENATION(reference_ops, float); 145 } else { 146 TF_LITE_CONCATENATION(optimized_ops, float); 147 } 148 break; 149 case kTfLiteInt32: 150 if (kernel_type == kReference) { 151 TF_LITE_CONCATENATION(reference_ops, int32); 152 } else { 153 TF_LITE_CONCATENATION(optimized_ops, int32); 154 } 155 break; 156 case kTfLiteUInt8: 157 if (kernel_type == kReference) { 158 TF_LITE_CONCATENATION_QUANTIZED(reference_ops); 159 } else { 160 TF_LITE_CONCATENATION_QUANTIZED(optimized_ops); 161 } 162 break; 163 case kTfLiteInt8: { 164 if (kernel_type == kReference) { 165 TF_LITE_CONCATENATION(reference_ops, int8_t); 166 } else { 167 TF_LITE_CONCATENATION(optimized_ops, int8_t); 168 } 169 } break; 170 case kTfLiteInt64: 171 if (kernel_type == kReference) { 172 TF_LITE_CONCATENATION(reference_ops, int64_t); 173 } else { 174 TF_LITE_CONCATENATION(optimized_ops, int64_t); 175 } 176 break; 177 178 default: 179 context->ReportError(context, 180 "Only float32 and uint8 are currently supported."); 181 return kTfLiteError; 182 } 183 184 #undef TF_LITE_CONCATENATION_QUANTIZED 185 #undef TF_LITE_CONCATENATION 186 187 return kTfLiteOk; 188 } 189 190 #undef TF_LITE_MACRO_DISPATCH 191 192 } // namespace concatenation 193 194 TfLiteRegistration* Register_CONCATENATION_REF() { 195 static TfLiteRegistration r = { 196 nullptr, nullptr, concatenation::Prepare, 197 concatenation::Eval<concatenation::kReference>}; 198 return &r; 199 } 200 201 TfLiteRegistration* Register_CONCATENATION_GENERIC_OPT() { 202 static TfLiteRegistration r = { 203 nullptr, nullptr, concatenation::Prepare, 204 concatenation::Eval<concatenation::kGenericOptimized>}; 205 return &r; 206 } 207 208 TfLiteRegistration* Register_CONCATENATION() { 209 // TODO(ahentz): It turns out the two versions of Concatenation are almost 210 // identical, so we should consider removing one. 211 return Register_CONCATENATION_GENERIC_OPT(); 212 } 213 214 } // namespace builtin 215 } // namespace ops 216 } // namespace tflite 217