1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 // Contains the implementation of the operations. 18 19 #define LOG_TAG "Operations" 20 21 #include "CpuOperationUtils.h" 22 #include "Operations.h" 23 24 #include "tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h" 25 26 #include "Tracing.h" 27 28 namespace android { 29 namespace nn { 30 31 bool stridedSliceGeneric(const uint8_t* inputData, const Shape& inputShape, 32 const int32_t* beginData, const int32_t* endData, 33 const int32_t* stridesData, int32_t beginMask, int32_t endMask, 34 int32_t shrinkAxisMask, uint8_t* outputData, const Shape& outputShape) { 35 NNTRACE_TRANS("stridedSliceGeneric"); 36 // This Op only supports 1-4D cases and since we use the reference 4D 37 // implementation, the 1-3D tensors are mapped to 4D. 38 const int kMaxDim = 4; 39 40 std::vector<int> starts; 41 std::vector<int> stops; 42 std::vector<int> strides; 43 44 int32_t numInputDims = static_cast<int32_t>(getNumberOfDimensions(inputShape)); 45 for (int32_t idx = numInputDims - 1; idx >= 0; --idx) { 46 starts.emplace_back(beginData[idx]); 47 stops.emplace_back(endData[idx]); 48 strides.emplace_back(stridesData[idx]); 49 } 50 51 for (int i = numInputDims; i < kMaxDim; i++) { 52 starts.emplace_back(0); 53 stops.emplace_back(1); 54 strides.emplace_back(1); 55 } 56 57 beginMask = ReverseMaskBits(beginMask, numInputDims); 58 endMask = ReverseMaskBits(endMask, numInputDims); 59 shrinkAxisMask = ReverseMaskBits(shrinkAxisMask, numInputDims); 60 61 if (inputShape.type == OperandType::TENSOR_FLOAT32) { 62 NNTRACE_COMP_SWITCH("reference_ops::StridedSlice::float"); 63 tflite::reference_ops::StridedSlice( 64 reinterpret_cast<const float*>(inputData), convertShapeToDims(inputShape), 65 beginMask, endMask, shrinkAxisMask, starts, stops, strides, 66 reinterpret_cast<float*>(outputData), convertShapeToDims(outputShape)); 67 } else if (inputShape.type == OperandType::TENSOR_FLOAT16) { 68 NNTRACE_COMP_SWITCH("reference_ops::StridedSlice::float16"); 69 tflite::reference_ops::StridedSlice( 70 reinterpret_cast<const _Float16*>(inputData), convertShapeToDims(inputShape), 71 beginMask, endMask, shrinkAxisMask, starts, stops, strides, 72 reinterpret_cast<_Float16*>(outputData), convertShapeToDims(outputShape)); 73 } else if (inputShape.type == OperandType::TENSOR_QUANT8_ASYMM) { 74 NNTRACE_COMP_SWITCH("reference_ops::StridedSlice::uint8"); 75 tflite::reference_ops::StridedSlice( 76 reinterpret_cast<const uint8_t*>(inputData), convertShapeToDims(inputShape), 77 beginMask, endMask, shrinkAxisMask, starts, stops, strides, 78 reinterpret_cast<uint8_t*>(outputData), convertShapeToDims(outputShape)); 79 } else { 80 LOG(ERROR) << "Unsupported data type"; 81 return false; 82 } 83 84 return true; 85 } 86 87 } // namespace nn 88 } // namespace android 89