Home | History | Annotate | Download | only in operations
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 #include "CpuOperationUtils.h"
     17 #include "IndexedShapeWrapper.h"
     18 #include "OperationResolver.h"
     19 
     20 #include <vector>
     21 
     22 namespace android {
     23 namespace nn {
     24 namespace slice {
     25 
     26 constexpr char kOperationName[] = "SLICE";
     27 
     28 constexpr uint32_t kNumInputs = 3;
     29 constexpr uint32_t kInputTensor = 0;
     30 constexpr uint32_t kBeginTensor = 1;
     31 constexpr uint32_t kSizeTensor = 2;
     32 
     33 constexpr uint32_t kNumOutputs = 1;
     34 constexpr uint32_t kOutputTensor = 0;
     35 
     36 namespace {
     37 
     38 template <typename T>
     39 void addVectors(const std::vector<T>& a, const std::vector<T>& b, std::vector<T>* res) {
     40     for (int i = 0; i < res->size(); ++i) {
     41         res->at(i) = a[i] + b[i];
     42     }
     43 }
     44 
     45 template <typename T>
     46 bool evalGeneric(const T* inputData, const Shape& inputShape, const int32_t* beginData,
     47                  const Shape& beginShape, const int32_t* sizeData, const Shape& sizeShape,
     48                  T* outputData, const Shape& outputShape) {
     49     const int outputSize = getNumberOfElements(outputShape);
     50     const IndexedShapeWrapper indexedOutput = IndexedShapeWrapper(outputShape);
     51     const IndexedShapeWrapper indexedInput = IndexedShapeWrapper(inputShape);
     52     std::vector<uint32_t> outputIndex(getNumberOfDimensions(outputShape), 0);
     53     std::vector<uint32_t> beginIndex(getSizeOfDimension(beginShape, 0));
     54     std::vector<uint32_t> inputIndex(getNumberOfDimensions(inputShape));
     55 
     56     for (int i = 0; i < beginIndex.size(); ++i) {
     57         beginIndex[i] = static_cast<uint32_t>(beginData[i]);
     58     }
     59 
     60     bool lastIndex = false;
     61     uint32_t outputOffset;
     62     uint32_t inputOffset;
     63 
     64     do {
     65         addVectors(outputIndex, beginIndex, &inputIndex);
     66 
     67         NN_RET_CHECK(indexedOutput.indexToFlatIndex(outputIndex, &outputOffset));
     68         NN_RET_CHECK(indexedInput.indexToFlatIndex(inputIndex, &inputOffset));
     69 
     70         outputData[outputOffset] = inputData[inputOffset];
     71         NN_RET_CHECK(indexedOutput.nextIndexInplace(&outputIndex, &lastIndex));
     72     } while (!lastIndex);
     73     return true;
     74 }
     75 
     76 }  // namespace
     77 
     78 bool validate(const IOperationValidationContext* context) {
     79     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
     80     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
     81 
     82     const OperandType inputType = context->getInputType(kInputTensor);
     83     NN_RET_CHECK(
     84             inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_FLOAT32 ||
     85             inputType == OperandType::TENSOR_INT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM)
     86             << "Unsupported tensor type for operation " << kOperationName;
     87     NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
     88     return validateInputTypes(context,
     89                               {inputType, OperandType::TENSOR_INT32, OperandType::TENSOR_INT32}) &&
     90            validateOutputTypes(context, {inputType});
     91 }
     92 
     93 bool prepare(IOperationExecutionContext* context) {
     94     const Shape& inputShape = context->getInputShape(kInputTensor);
     95     const int32_t n_dims = getNumberOfDimensions(inputShape);
     96     NN_RET_CHECK(n_dims > 0);
     97 
     98     const Shape& beginShape = context->getInputShape(kBeginTensor);
     99     NN_RET_CHECK_EQ(getNumberOfDimensions(beginShape), 1);
    100     NN_RET_CHECK_EQ(getSizeOfDimension(beginShape, 0), n_dims);
    101 
    102     const Shape& sizeShape = context->getInputShape(kSizeTensor);
    103     NN_RET_CHECK_EQ(getNumberOfDimensions(sizeShape), 1);
    104     NN_RET_CHECK_EQ(getSizeOfDimension(sizeShape, 0), n_dims);
    105 
    106     const int32_t* beginData = context->getInputBuffer<int32_t>(kBeginTensor);
    107     const int32_t* sizeData = context->getInputBuffer<int32_t>(kSizeTensor);
    108 
    109     Shape outputShape = context->getOutputShape(kOutputTensor);
    110     outputShape.dimensions.resize(n_dims);
    111     for (int i = 0; i < n_dims; ++i) {
    112         const int32_t sliceBegin = beginData[i];
    113         int32_t sliceSize = sizeData[i];
    114         if (sliceSize == -1) {
    115             sliceSize = getSizeOfDimension(inputShape, i) - sliceBegin;
    116         }
    117         NN_RET_CHECK_LE(beginData[i], getSizeOfDimension(inputShape, i));
    118         NN_RET_CHECK_GE(sliceSize, 0);
    119         NN_RET_CHECK_LE(sliceBegin + sliceSize, getSizeOfDimension(inputShape, i));
    120         outputShape.dimensions[i] = sliceSize;
    121     }
    122     return context->setOutputShape(kOutputTensor, outputShape);
    123 }
    124 
    125 bool execute(IOperationExecutionContext* context) {
    126     // Bypass execution in the case of zero-sized input.
    127     if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    128     switch (context->getInputType(kInputTensor)) {
    129         case OperandType::TENSOR_FLOAT16:
    130             return evalGeneric(context->getInputBuffer<_Float16>(kInputTensor),
    131                                context->getInputShape(kInputTensor),
    132                                context->getInputBuffer<int32_t>(kBeginTensor),
    133                                context->getInputShape(kBeginTensor),
    134                                context->getInputBuffer<int32_t>(kSizeTensor),
    135                                context->getInputShape(kSizeTensor),
    136                                context->getOutputBuffer<_Float16>(kOutputTensor),
    137                                context->getOutputShape(kOutputTensor));
    138         case OperandType::TENSOR_FLOAT32:
    139             return evalGeneric(context->getInputBuffer<float>(kInputTensor),
    140                                context->getInputShape(kInputTensor),
    141                                context->getInputBuffer<int32_t>(kBeginTensor),
    142                                context->getInputShape(kBeginTensor),
    143                                context->getInputBuffer<int32_t>(kSizeTensor),
    144                                context->getInputShape(kSizeTensor),
    145                                context->getOutputBuffer<float>(kOutputTensor),
    146                                context->getOutputShape(kOutputTensor));
    147         case OperandType::TENSOR_INT32:
    148             return evalGeneric(context->getInputBuffer<int32_t>(kInputTensor),
    149                                context->getInputShape(kInputTensor),
    150                                context->getInputBuffer<int32_t>(kBeginTensor),
    151                                context->getInputShape(kBeginTensor),
    152                                context->getInputBuffer<int32_t>(kSizeTensor),
    153                                context->getInputShape(kSizeTensor),
    154                                context->getOutputBuffer<int32_t>(kOutputTensor),
    155                                context->getOutputShape(kOutputTensor));
    156         case OperandType::TENSOR_QUANT8_ASYMM:
    157             return evalGeneric(context->getInputBuffer<uint8_t>(kInputTensor),
    158                                context->getInputShape(kInputTensor),
    159                                context->getInputBuffer<int32_t>(kBeginTensor),
    160                                context->getInputShape(kBeginTensor),
    161                                context->getInputBuffer<int32_t>(kSizeTensor),
    162                                context->getInputShape(kSizeTensor),
    163                                context->getOutputBuffer<uint8_t>(kOutputTensor),
    164                                context->getOutputShape(kOutputTensor));
    165         default:
    166             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
    167     }
    168 }
    169 
    170 }  // namespace slice
    171 
    172 NN_REGISTER_OPERATION(SLICE, slice::kOperationName, slice::validate, slice::prepare, slice::execute,
    173                       .allowZeroSizedInput = true);
    174 
    175 }  // namespace nn
    176 }  // namespace android
    177