1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/lite/c/builtin_op_data.h" 17 #include "tensorflow/lite/c/c_api_internal.h" 18 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" 19 #include "tensorflow/lite/kernels/internal/tensor.h" 20 #include "tensorflow/lite/kernels/kernel_util.h" 21 22 namespace tflite { 23 namespace ops { 24 namespace builtin { 25 namespace pack { 26 namespace { 27 28 constexpr int kOutputTensor = 0; 29 30 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { 31 const TfLitePackParams* data = 32 reinterpret_cast<TfLitePackParams*>(node->builtin_data); 33 34 TF_LITE_ENSURE_EQ(context, NumInputs(node), data->values_count); 35 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); 36 37 const TfLiteTensor* input0 = GetInput(context, node, 0); 38 TF_LITE_ENSURE(context, NumDimensions(input0) >= data->axis); 39 // TODO(renjieliu): Support negative axis. 40 TF_LITE_ENSURE(context, data->axis >= 0); 41 if (input0->type != kTfLiteInt32 && input0->type != kTfLiteFloat32 && 42 input0->type != kTfLiteUInt8 && input0->type != kTfLiteInt8 && 43 input0->type != kTfLiteInt16 && input0->type != kTfLiteInt64) { 44 context->ReportError(context, "Type '%s' is not supported by pack.", 45 TfLiteTypeGetName(input0->type)); 46 return kTfLiteError; 47 } 48 // Make sure all inputs have the same shape and type. 49 for (int i = 1; i < data->values_count; ++i) { 50 const TfLiteTensor* input = GetInput(context, node, i); 51 TF_LITE_ENSURE(context, HaveSameShapes(input0, input)); 52 TF_LITE_ENSURE_EQ(context, input0->type, input->type); 53 } 54 55 // Resize output. rank R will become rank R + 1 56 const int dimension_size = NumDimensions(input0) + 1; 57 const TfLiteIntArray* input_shape = input0->dims; 58 TfLiteIntArray* output_shape = TfLiteIntArrayCreate(dimension_size); 59 int i = 0; 60 for (int index = 0; index < dimension_size; ++index) { 61 if (index == data->axis) { 62 output_shape->data[index] = data->values_count; 63 } else { 64 output_shape->data[index] = input_shape->data[i++]; 65 } 66 } 67 68 TfLiteTensor* output = GetOutput(context, node, kOutputTensor); 69 TF_LITE_ENSURE_EQ(context, output->type, input0->type); 70 71 // Guarantee input/output quantization params match as we do not support 72 // packing quantized tensors. 73 for (int i = 0; i < data->values_count; i++) { 74 const TfLiteTensor* input = GetInput(context, node, i); 75 TF_LITE_ENSURE_EQ(context, input->params.zero_point, 76 output->params.zero_point); 77 TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale); 78 } 79 80 return context->ResizeTensor(context, output, output_shape); 81 } 82 83 template <typename T> 84 void PackImpl(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* output, 85 int values_count, int axis) { 86 VectorOfTensors<T> all_inputs(*context, *node->inputs); 87 tflite::PackParams op_params; 88 op_params.axis = axis; 89 op_params.inputs_count = values_count; 90 91 reference_ops::Pack<T>(op_params, all_inputs.shapes(), all_inputs.data(), 92 GetTensorShape(output), GetTensorData<T>(output)); 93 } 94 95 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { 96 const TfLitePackParams* data = 97 reinterpret_cast<TfLitePackParams*>(node->builtin_data); 98 99 TfLiteTensor* output = GetOutput(context, node, kOutputTensor); 100 switch (output->type) { 101 case kTfLiteFloat32: { 102 PackImpl<float>(context, node, output, data->values_count, data->axis); 103 break; 104 } 105 case kTfLiteUInt8: { 106 PackImpl<uint8_t>(context, node, output, data->values_count, data->axis); 107 break; 108 } 109 case kTfLiteInt8: { 110 PackImpl<int8_t>(context, node, output, data->values_count, data->axis); 111 break; 112 } 113 case kTfLiteInt32: { 114 PackImpl<int32_t>(context, node, output, data->values_count, data->axis); 115 break; 116 } 117 case kTfLiteInt64: { 118 PackImpl<int64_t>(context, node, output, data->values_count, data->axis); 119 break; 120 } 121 default: { 122 context->ReportError(context, "Type '%s' is not supported by pack.", 123 TfLiteTypeGetName(output->type)); 124 return kTfLiteError; 125 } 126 } 127 128 return kTfLiteOk; 129 } 130 131 } // namespace 132 } // namespace pack 133 134 TfLiteRegistration* Register_PACK() { 135 static TfLiteRegistration r = {nullptr, nullptr, pack::Prepare, pack::Eval}; 136 return &r; 137 } 138 139 } // namespace builtin 140 } // namespace ops 141 } // namespace tflite 142