/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
#include "tensorflow/contrib/lite/string_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace gather {
constexpr int kInputTensor = 0;
constexpr int kInputPositions = 1;
constexpr int kOutputTensor = 0;

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const auto* params =
      reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data);
  TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TfLiteTensor* positions = GetInput(context, node, kInputPositions);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  // Only INT32 positions are supported.
  TF_LITE_ENSURE_EQ(context, positions->type, kTfLiteInt32);
  // Check that input and output types match.
  TF_LITE_ENSURE_EQ(context, input->type, output->type);
  // TODO(mgubin): only 0D or 1D positions are currently supported.
  TF_LITE_ENSURE(context, NumDimensions(positions) <= 1);
  // TODO(mgubin): Only default axis == 0 is supported.
  TF_LITE_ENSURE_EQ(context, params->axis, 0);
  // Check conditions for different types.
  switch (input->type) {
    case kTfLiteFloat32:
    case kTfLiteUInt8:
    case kTfLiteInt32: {
      // Fully supported by optimized_ops::Gather.
    } break;

    case kTfLiteString: {
      // Only 1D input is supported.
      TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
    } break;
    default:
      context->ReportError(
          context,
          "Only float32, uint8, int32 and string types are supported");
      return kTfLiteError;
  }
  const int num_dimensions =
      NumDimensions(input) + NumDimensions(positions) - 1;
  TF_LITE_ENSURE(context, params->axis <= num_dimensions);
  // Output shape: the input dims before the axis, then the positions dims,
  // then the input dims after the axis. With axis == 0 this is just the
  // positions dims followed by the remaining input dims.
  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions);
  int output_index = 0;
  for (int i = 0; i < params->axis; ++i) {
    output_shape->data[output_index++] = input->dims->data[i];
  }
  for (int i = 0; i < positions->dims->size; ++i) {
    output_shape->data[output_index++] = positions->dims->data[i];
  }
  for (int i = params->axis + 1; i < input->dims->size; ++i) {
    output_shape->data[output_index++] = input->dims->data[i];
  }
  return context->ResizeTensor(context, output, output_shape);
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TfLiteTensor* positions = GetInput(context, node, kInputPositions);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  const int input_rank = NumDimensions(input);
#define TF_LITE_GATHER(data_type, index_type)                             \
  optimized_ops::Gather(                                                  \
      GetTensorData<data_type>(input), GetTensorDims(input), input_rank,  \
      GetTensorData<index_type>(positions), GetTensorDims(positions),     \
      GetTensorData<data_type>(output), GetTensorDims(output));
  switch (input->type) {
    case kTfLiteFloat32:
      TF_LITE_GATHER(float, int32_t);
      break;
    case kTfLiteUInt8:
      TF_LITE_GATHER(uint8_t, int32_t);
      break;
    case kTfLiteInt32:
      TF_LITE_GATHER(int32_t, int32_t);
      break;
    case kTfLiteString: {
      DynamicBuffer buffer;
      const int32_t* indexes = positions->data.i32;
      const int num_strings = GetStringCount(input);
      for (int i = 0; i < positions->dims->data[0]; ++i) {
        const int pos = indexes[i];
        // Reject out-of-range positions before dereferencing the string.
        TF_LITE_ENSURE(context, pos >= 0 && pos < num_strings);
        const auto string_ref = GetString(input, pos);
        buffer.AddString(string_ref.str, string_ref.len);
      }
      buffer.WriteToTensor(output);
    } break;
    default:
      return kTfLiteError;
  }
#undef TF_LITE_GATHER
  return kTfLiteOk;
}
}  // namespace gather

TfLiteRegistration* Register_GATHER() {
  static TfLiteRegistration r = {nullptr, nullptr, gather::Prepare,
                                 gather::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite
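
// Usage sketch (illustrative, assuming the builtin resolver declared in
// tensorflow/contrib/lite/kernels/register.h): Register_GATHER() is normally
// looked up through an op resolver rather than called directly, e.g.
//
//   tflite::ops::builtin::BuiltinOpResolver resolver;  // GATHER is included.
//   // Or, with a custom tflite::MutableOpResolver:
//   // resolver.AddBuiltin(tflite::BuiltinOperator_GATHER,
//   //                     tflite::ops::builtin::Register_GATHER());
//
// Shape example for Prepare with axis == 0: a [4, 3] float32 input gathered
// with [2] int32 positions resizes the output to [2, 3].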