/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "OperationsUtils"

#include "OperationsUtils.h"
#include "Operations.h"
#include "Utils.h"

#include <algorithm>
#include <cmath>
#include <functional>
#include <limits>
#include <vector>

namespace android {
namespace nn {

namespace {

bool validateOperandTypes(const std::vector<OperandType>& expectedTypes, const char* tag,
                          uint32_t operandCount,
                          std::function<OperandType(uint32_t)> getOperandType) {
    NN_RET_CHECK_EQ(operandCount, expectedTypes.size());
    for (uint32_t i = 0; i < operandCount; ++i) {
        OperandType type = getOperandType(i);
        NN_RET_CHECK(type == expectedTypes[i])
                << "Invalid " << tag << " tensor type " << toString(type) << " for " << tag << " "
                << i << ", expected " << toString(expectedTypes[i]);
    }
    return true;
}

}  // namespace

bool validateInputTypes(const IOperationValidationContext* context,
                        const std::vector<OperandType>& expectedTypes) {
    return validateOperandTypes(expectedTypes, "input", context->getNumInputs(),
                                [context](uint32_t index) { return context->getInputType(index); });
}

bool validateOutputTypes(const IOperationValidationContext* context,
                         const std::vector<OperandType>& expectedTypes) {
    return validateOperandTypes(
            expectedTypes, "output", context->getNumOutputs(),
            [context](uint32_t index) { return context->getOutputType(index); });
}

bool validateHalVersion(const IOperationValidationContext* context,
                        HalVersion minSupportedHalVersion) {
    if (context->getHalVersion() < minSupportedHalVersion) {
        NN_RET_CHECK_FAIL() << "The given inputs and outputs are only supported in "
                            << toString(minSupportedHalVersion) << " and later (validating using "
                            << toString(context->getHalVersion()) << ")";
    }
    return true;
}

bool SameShape(const Shape& in1, const Shape& in2) {
    if (in1.type != in2.type || in1.dimensions.size() != in2.dimensions.size()) {
        return false;
    }
    for (size_t i = 0; i < in1.dimensions.size(); i++) {
        if (in1.dimensions[i] != in2.dimensions[i]) {
            return false;
        }
    }
    return true;
}

bool SetShape(const Shape& in, Shape* out) {
    if (in.type != out->type) {
        return false;
    }
    out->dimensions = in.dimensions;
    return true;
}

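// Merges two dimension vectors, treating 0 as "unknown at this axis": an
// unknown dimension is filled in from the other operand, while two known
// dimensions must match exactly. E.g. combining {2, 0, 3} with {2, 5, 3}
// yields {2, 5, 3}, while combining {2, 4} with {2, 5} fails.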
bool combineDimensions(const std::vector<uint32_t>& lhs, const std::vector<uint32_t>& rhs,
                       std::vector<uint32_t>* combined) {
    if (rhs.empty()) {
        *combined = lhs;
        return true;
    }
    if (lhs.empty()) {
        *combined = rhs;
        return true;
    }
    NN_RET_CHECK_EQ(lhs.size(), rhs.size()) << "incompatible ranks";
    combined->resize(lhs.size());
    for (uint32_t i = 0; i < lhs.size(); i++) {
        if (lhs[i] == 0) {
            (*combined)[i] = rhs[i];
            continue;
        }
        if (rhs[i] == 0) {
            (*combined)[i] = lhs[i];
            continue;
        }
        NN_RET_CHECK_EQ(lhs[i], rhs[i]) << "incompatible dimension: " << i;
        (*combined)[i] = lhs[i];
    }
    return true;
}

uint32_t getNumberOfElements(const Shape& shape) {
    uint32_t count = 1;
    for (size_t i = 0; i < shape.dimensions.size(); i++) {
        count *= shape.dimensions[i];
    }
    return count;
}

uint32_t getNumberOfElements(const Shape& shape,
                             size_t firstAxisInclusive,
                             size_t lastAxisExclusive) {
    nnAssert(0 <= firstAxisInclusive);
    nnAssert(firstAxisInclusive <= lastAxisExclusive);
    nnAssert(lastAxisExclusive <= shape.dimensions.size());
    uint32_t count = 1;
    for (size_t i = firstAxisInclusive; i < lastAxisExclusive; i++) {
        count *= shape.dimensions[i];
    }
    return count;
}

uint32_t getNumberOfDimensions(const Shape& shape) {
    return shape.dimensions.size();
}

uint32_t getSizeOfDimension(const Shape& shape, uint32_t dimensionIdx) {
    nnAssert(0 <= dimensionIdx && dimensionIdx < shape.dimensions.size());
    return shape.dimensions[dimensionIdx];
}

bool handleNegativeAxis(int32_t numberOfDimensions, int32_t* axis) {
    NN_CHECK(-numberOfDimensions <= *axis && *axis < numberOfDimensions);
    if (*axis < 0) {
        *axis += numberOfDimensions;
    }
    return true;
}

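// Decomposes a positive double `m` as m = q * 2^shift with q in [0.5, 1),
// then stores q as a Q0.31 fixed-point value: quantized_multiplier =
// round(q * 2^31). If rounding pushes q_fixed up to exactly 2^31, it is
// halved and the shift incremented instead, so the result always fits in
// int32. E.g. 0.75 yields quantized_multiplier = 0x60000000 and shift = 0.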
bool QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier, int* shift) {
    if (double_multiplier == 0.) {
        *quantized_multiplier = 0;
        *shift = 0;
        return true;
    }
    const double q = std::frexp(double_multiplier, shift);
    auto q_fixed = static_cast<int64_t>(std::round(q * (1ll << 31)));
    NN_RET_CHECK(q_fixed <= (1ll << 31));
    if (q_fixed == (1ll << 31)) {
        q_fixed /= 2;
        ++*shift;
    }
    NN_RET_CHECK_LE(q_fixed, std::numeric_limits<int32_t>::max());
    *quantized_multiplier = static_cast<int32_t>(q_fixed);
    return true;
}

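// Same decomposition, specialized for a multiplier in [0, 1): the
// (non-negative) exponent is returned as a right shift, so
// double_multiplier ~= quantized_multiplier * 2^-(31 + right_shift).
// E.g. 0.25 = 0.5 * 2^-1 yields quantized_multiplier = 2^30 and
// right_shift = 1.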
bool QuantizeMultiplierSmallerThanOne(double double_multiplier,
                                      int32_t* quantized_multiplier,
                                      int32_t* right_shift) {
    NN_OPS_CHECK(double_multiplier >= 0.);
    NN_OPS_CHECK(double_multiplier < 1.);
    if (double_multiplier == 0.) {
        *quantized_multiplier = 0;
        *right_shift = 0;
        return true;
    }
    NN_OPS_CHECK(double_multiplier > 0.);
    const double q = std::frexp(double_multiplier, right_shift);
    *right_shift *= -1;
    int64_t q_fixed = static_cast<int64_t>(std::round(q * (1LL << 31)));
    NN_OPS_CHECK(q_fixed <= (1LL << 31));
    if (q_fixed == (1LL << 31)) {
        q_fixed /= 2;
        --*right_shift;
    }
    NN_OPS_CHECK(*right_shift >= 0);
    NN_OPS_CHECK(q_fixed <= std::numeric_limits<int32_t>::max());
    *quantized_multiplier = static_cast<int32_t>(q_fixed);
    return true;
}

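// Same decomposition, specialized for a multiplier greater than 1: here the
// exponent is returned as a left shift, so
// double_multiplier ~= quantized_multiplier * 2^(left_shift - 31).
// E.g. 3.0 = 0.75 * 2^2 yields quantized_multiplier = 0x60000000 and
// left_shift = 2.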
bool QuantizeMultiplierGreaterThanOne(double double_multiplier,
                                      int32_t* quantized_multiplier,
                                      int* left_shift) {
    NN_OPS_CHECK(double_multiplier > 1.);
    const double q = std::frexp(double_multiplier, left_shift);
    int64_t q_fixed = static_cast<int64_t>(std::round(q * (1LL << 31)));
    NN_OPS_CHECK(q_fixed <= (1LL << 31));
    if (q_fixed == (1LL << 31)) {
        q_fixed /= 2;
        ++*left_shift;
    }
    NN_OPS_CHECK(*left_shift >= 0);
    NN_OPS_CHECK(q_fixed <= std::numeric_limits<int32_t>::max());
    *quantized_multiplier = static_cast<int32_t>(q_fixed);
    return true;
}

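// For a quantized convolution, the int32 accumulator carries values scaled by
// input_scale * filter_scale, so the bias must use that same scale (checked
// below to within a small tolerance) and the accumulator must be rescaled to
// the output by multiplier = input_scale * filter_scale / output_scale.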
bool GetQuantizedConvolutionMultipler(const Shape& inputShape, const Shape& filterShape,
                                      const Shape& biasShape, const Shape& outputShape,
                                      double* multiplier) {
    // Upcast bias and input_product to double.
    const double input_product_scale = inputShape.scale * filterShape.scale;
    const double bias_scale = biasShape.scale;

    // The following conditions must be guaranteed by the training pipeline.
    NN_OPS_CHECK(std::abs(input_product_scale - bias_scale) <=
                 1e-6 * std::min(input_product_scale, bias_scale));
    NN_OPS_CHECK(input_product_scale >= 0);
    *multiplier = input_product_scale / outputShape.scale;
    return true;
}

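// Computes the [act_min, act_max] clamp bounds for a fused activation in the
// quantized domain of the output: a real value f maps to
// zero_point + round(f / scale), and the result is clamped to the uint8
// range [0, 255].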
void CalculateActivationRangeUint8(int32_t activation,
                                   const Shape& outputShape,
                                   int32_t* act_min,
                                   int32_t* act_max) {
    const int32_t qmin = std::numeric_limits<uint8_t>::min();
    const int32_t qmax = std::numeric_limits<uint8_t>::max();

    const auto scale = outputShape.scale;
    const auto zero_point = outputShape.offset;

    auto quantize = [scale, zero_point](float f) {
        return zero_point + static_cast<int32_t>(std::round(f / scale));
    };

    if (activation == kActivationRelu) {
        *act_min = std::max(qmin, quantize(0.0));
        *act_max = qmax;
    } else if (activation == kActivationRelu6) {
        *act_min = std::max(qmin, quantize(0.0));
        *act_max = std::min(qmax, quantize(6.0));
    } else if (activation == kActivationRelu1) {
        *act_min = std::max(qmin, quantize(-1.0));
        *act_max = std::min(qmax, quantize(1.0));
    } else if (activation == kActivationNone) {
        *act_min = qmin;
        *act_max = qmax;
    } else {
        LOG(ERROR) << "Unsupported fused activation function.";
    }
}

void CalculateActivationRangeFloat(int32_t activation,
                                   float* activation_min,
                                   float* activation_max) {
    if (activation == kActivationRelu) {
        *activation_min = 0.f;
        *activation_max = std::numeric_limits<float>::max();
    } else if (activation == kActivationRelu6) {
        *activation_min = 0.f;
        *activation_max = 6.f;
    } else if (activation == kActivationRelu1) {
        *activation_min = -1.f;
        *activation_max = 1.f;
    } else if (activation == kActivationNone) {
        *activation_min = std::numeric_limits<float>::lowest();
        *activation_max = std::numeric_limits<float>::max();
    } else {
        LOG(ERROR) << "Unsupported fused activation function.";
    }
}

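// Computes the largest input magnitude that remains representable after
// fixed-point rescaling: floor((2^input_integer_bits - 1) *
// 2^(31 - input_integer_bits) / 2^input_left_shift). Inputs whose magnitude
// exceeds this radius saturate, so callers can treat them as a special case.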
int32_t CalculateInputRadius(int input_integer_bits, int input_left_shift) {
    const double max_input_rescaled = 1.0 * ((1 << input_integer_bits) - 1) *
                                      (1LL << (31 - input_integer_bits)) /
                                      (1LL << input_left_shift);
    // Tighten bound using floor.  Suppose that we could use the exact value.
    // After scaling the difference, the result would be at the maximum.  Thus we
    // must ensure that our value has lower magnitude.
    return static_cast<int32_t>(std::floor(max_input_rescaled));
}

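// Converts implicit padding into explicit head/tail padding for one spatial
// axis. For SAME padding, out_size = ceil(in_size / stride) and the total
// padding (out_size - 1) * stride + effective_filter_size - in_size is split
// evenly, head first; for VALID padding both sides stay 0. E.g. in_size = 5,
// stride = 2, filter_size = 3, dilation = 1 gives out_size = 3 and
// padding_head = padding_tail = 1.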
void calculateExplicitPaddingImpl(int32_t in_size, int32_t stride, int32_t dilation_factor,
                                  int32_t filter_size, int32_t padding_implicit,
                                  bool isTransposeConv, int32_t* padding_head,
                                  int32_t* padding_tail) {
    *padding_head = 0;
    *padding_tail = 0;

    int32_t effective_filter_size = (filter_size - 1) * dilation_factor + 1;

    if (padding_implicit == kPaddingSame) {
        int32_t out_size = (in_size + stride - 1) / stride;
        int32_t tmp = (out_size - 1) * stride + effective_filter_size;
        if (tmp > in_size) {
            *padding_head = (tmp - in_size) / 2;
            *padding_tail = (tmp - in_size) - *padding_head;
        }
        // For transpose conv, make padding tail fit tightly to the end of the last stride.
        if (isTransposeConv) {
            *padding_tail = (tmp - in_size) - *padding_head;
        }
    }
}

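// Computes the NumPy-style broadcast of two shapes: dimensions are aligned
// from the trailing axis, missing leading dimensions are treated as 1, and
// each pair must either match or contain a 1 (the other size wins).
// E.g. {3, 1, 5} and {4, 5} broadcast to {3, 4, 5}.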
bool calculateBroadcastedShape(const Shape& in1, const Shape& in2, Shape* out) {
    NN_RET_CHECK(in1.type == in2.type);
    uint32_t numberOfDims1 = getNumberOfDimensions(in1);
    uint32_t numberOfDims2 = getNumberOfDimensions(in2);
    uint32_t maxDims = std::max(numberOfDims1, numberOfDims2);
    out->dimensions = std::vector<uint32_t>(maxDims);
    for (uint32_t i = 1; i <= maxDims; i++) {
        uint32_t dim1 = 1;
        if (i <= numberOfDims1) {
            dim1 = getSizeOfDimension(in1, numberOfDims1 - i);
        }
        uint32_t dim2 = 1;
        if (i <= numberOfDims2) {
            dim2 = getSizeOfDimension(in2, numberOfDims2 - i);
        }
        if (dim1 != dim2 && dim1 != 1 && dim2 != 1) {
            LOG(ERROR) << "Dimensions mismatch for broadcast:\n"
                       << "First tensor: dimension " << numberOfDims1 - i << " of size " << dim1
                       << "\nSecond tensor: dimension " << numberOfDims2 - i << " of size " << dim2;
            return false;
        }
        out->dimensions[maxDims - i] = (dim1 == 1) ? dim2 : dim1;
    }
    return true;
}

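// Converts a quantized uint8 value between two quantization parameter sets by
// dequantizing with the old scale/offset and requantizing with the new ones:
// real = (value - oldOffset) * oldScale, then
// round(real / newScale + newOffset), clamped to [0, 255].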
uint8_t requantize(uint8_t value, const Shape& oldShape, const Shape& newShape) {
    double doubleValue = (value - oldShape.offset) * oldShape.scale;
    double doubleRet = doubleValue / newShape.scale + newShape.offset;
    if (doubleRet < 0) return 0;
    if (doubleRet > 255) return 255;
    return static_cast<uint8_t>(std::round(doubleRet));
}

bool floorPrepare(const Shape& input, Shape* output) {
    return SetShape(input, output);
}

bool depthwiseConvPrepare(const Shape& input, const Shape& filter, const Shape& bias,
                          int32_t padding_left, int32_t padding_right, int32_t padding_top,
                          int32_t padding_bottom, int32_t stride_width, int32_t stride_height,
                          int32_t depth_multiplier, int32_t dilation_width_factor,
                          int32_t dilation_height_factor, Shape* output) {
    if (filter.type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
        NN_OPS_CHECK(input.type == OperandType::TENSOR_QUANT8_ASYMM);
    } else {
        NN_OPS_CHECK(input.type == filter.type);
    }
    if (input.type == OperandType::TENSOR_QUANT8_ASYMM) {
        NN_OPS_CHECK(bias.type == OperandType::TENSOR_INT32);
    } else {
        NN_OPS_CHECK(input.type == bias.type);
    }
    NN_OPS_CHECK(getNumberOfDimensions(input) == 4);
    NN_OPS_CHECK(getNumberOfDimensions(filter) == 4);
    NN_OPS_CHECK(getNumberOfDimensions(bias) == 1);

    NN_OPS_CHECK(getSizeOfDimension(filter, 3) == getSizeOfDimension(bias, 0));

    uint32_t channels_out = getSizeOfDimension(filter, 3);
    uint32_t channels_in  = getSizeOfDimension(input, 3);
    uint32_t width        = getSizeOfDimension(input, 2);
    uint32_t height       = getSizeOfDimension(input, 1);
    uint32_t filterWidth  = getSizeOfDimension(filter, 2);
    uint32_t filterHeight = getSizeOfDimension(filter, 1);
    uint32_t batches      = getSizeOfDimension(input, 0);

    NN_OPS_CHECK(depth_multiplier * channels_in == channels_out);
    int32_t effectiveFilterWidth = (filterWidth - 1) * dilation_width_factor + 1;
    int32_t effectiveFilterHeight = (filterHeight - 1) * dilation_height_factor + 1;
    NN_RET_CHECK_GT(effectiveFilterWidth, padding_left);
    NN_RET_CHECK_GT(effectiveFilterWidth, padding_right);
    NN_RET_CHECK_GT(effectiveFilterHeight, padding_top);
    NN_RET_CHECK_GT(effectiveFilterHeight, padding_bottom);

    uint32_t outWidth = computeOutSize(width, filterWidth, stride_width, dilation_width_factor,
                                       padding_left, padding_right);
    uint32_t outHeight = computeOutSize(height, filterHeight, stride_height, dilation_height_factor,
                                        padding_top, padding_bottom);

    output->type = input.type;
    output->dimensions = {batches, outHeight, outWidth, channels_out};
    return true;
}

bool genericActivationPrepare(const Shape& input, Shape* output) {
    NN_OPS_CHECK(getNumberOfDimensions(input) <= 4);
    return SetShape(input, output);
}

bool genericNormalizationPrepare(const Shape& input, Shape* output) {
    return SetShape(input, output);
}

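// Validates a RESHAPE and computes the output shape. One target dimension may
// be -1, in which case it is inferred so that the element count is preserved.
// E.g. reshaping a {2, 8} input (16 elements) with targetDims {4, -1}
// produces outDims {4, 4}.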
bool reshapePrepare(const Shape& input,
                    const int32_t* targetDims,
                    const int32_t targetDimsSize,
                    Shape* output) {
    // Reshape allows one of the targetDims components to have the
    // special -1 value, meaning it will be calculated automatically based on the
    // input. Here we calculate what that dimension should be so that the number
    // of output elements is the same as the number of input elements.
    int32_t numInputElements = (int32_t) getNumberOfElements(input);

    std::vector<uint32_t> outDims(targetDimsSize);
    int32_t numOutputElements = 1;
    int32_t stretchDim = -1;
    for (int32_t i = 0; i < targetDimsSize; ++i) {
        int32_t value = targetDims[i];
        if (value == -1) {
            NN_OPS_CHECK(stretchDim == -1);
            stretchDim = i;
        } else {
            numOutputElements *= value;
            outDims[i] = (uint32_t)value;
        }
    }
    if (stretchDim != -1) {
        int32_t stretchValue = numInputElements / numOutputElements;
        outDims[stretchDim] = (uint32_t) stretchValue;
        numOutputElements *= stretchValue;
    }

    NN_OPS_CHECK(numInputElements == numOutputElements);

    output->type = input.type;
    output->dimensions = outDims;
    output->offset = input.offset;
    output->scale = input.scale;

    return true;
}

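// DEPTH_TO_SPACE rearranges blockSize x blockSize groups of channels into
// spatial blocks: [b, h, w, c] -> [b, h * blockSize, w * blockSize,
// c / blockSize^2]. E.g. {1, 2, 2, 8} with blockSize 2 becomes {1, 4, 4, 2}.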
bool depthToSpacePrepare(const Shape& input,
                         int32_t blockSize,
                         Shape* output) {
    NN_OPS_CHECK(getNumberOfDimensions(input) == 4);
    NN_OPS_CHECK(blockSize > 0);

    uint32_t batches  = getSizeOfDimension(input, 0);
    uint32_t height   = getSizeOfDimension(input, 1);
    uint32_t width    = getSizeOfDimension(input, 2);
    uint32_t channels = getSizeOfDimension(input, 3);

    NN_OPS_CHECK(channels % (blockSize * blockSize) == 0);
    output->type = input.type;
    output->dimensions = {batches,
                          height * blockSize,
                          width * blockSize,
                          channels / (blockSize * blockSize)};
    output->offset = input.offset;
    output->scale = input.scale;

    return true;
}

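// SPACE_TO_DEPTH is the inverse: each blockSize x blockSize spatial block is
// folded into the channel dimension, so height and width must both be
// divisible by blockSize.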
bool spaceToDepthPrepare(const Shape& input,
                         int32_t blockSize,
                         Shape* output) {
    NN_OPS_CHECK(getNumberOfDimensions(input) == 4);
    NN_OPS_CHECK(blockSize > 0);

    uint32_t batches  = getSizeOfDimension(input, 0);
    uint32_t height   = getSizeOfDimension(input, 1);
    uint32_t width    = getSizeOfDimension(input, 2);
    uint32_t channels = getSizeOfDimension(input, 3);

    NN_OPS_CHECK(height % blockSize == 0);
    NN_OPS_CHECK(width % blockSize == 0);

    output->type = input.type;
    output->dimensions = {batches,
                          height / blockSize,
                          width / blockSize,
                          channels * (blockSize * blockSize)};
    output->offset = input.offset;
    output->scale = input.scale;

    return true;
}

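// EMBEDDING_LOOKUP gathers rows from a >= 2-D values tensor by the indices in
// a 1-D lookup tensor; the output keeps the trailing dimensions of the values
// tensor, with the first dimension replaced by the number of lookups.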
bool embeddingLookupPrepare(const Shape &valueShape,
                            const Shape &lookupShape,
                            Shape *outputShape) {
    NN_OPS_CHECK(getNumberOfDimensions(valueShape) >= 2);
    NN_OPS_CHECK(getNumberOfDimensions(lookupShape) == 1);

    const uint32_t rows     = getSizeOfDimension(valueShape, 0);
    const uint32_t columns  = getSizeOfDimension(valueShape, 1);

    const uint32_t lookups  = getSizeOfDimension(lookupShape, 0);

    outputShape->type = valueShape.type;
    outputShape->dimensions = { lookups, columns };
    for (uint32_t i = 2; i < getNumberOfDimensions(valueShape); i++) {
        outputShape->dimensions.push_back(getSizeOfDimension(valueShape, i));
    }
    outputShape->offset = valueShape.offset;
    outputShape->scale = valueShape.scale;

    return true;
}

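// HASHTABLE_LOOKUP also produces a second "hits" tensor: a uint8 vector with
// one entry per lookup indicating whether the key was found.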
bool hashtableLookupPrepare(const Shape &lookupShape,
                            const Shape &keyShape,
                            const Shape &valueShape,
                            Shape *outputShape,
                            Shape *hitShape) {
    NN_OPS_CHECK(getNumberOfDimensions(lookupShape) == 1);
    NN_OPS_CHECK(getNumberOfDimensions(keyShape) == 1);
    NN_OPS_CHECK(getNumberOfDimensions(valueShape) >= 1);

    const uint32_t lookups  = getSizeOfDimension(lookupShape, 0);
    const uint32_t keys     = getSizeOfDimension(keyShape, 0);
    const uint32_t rows     = getSizeOfDimension(valueShape, 0);
    outputShape->type = valueShape.type;
    outputShape->dimensions = { lookups };
    for (uint32_t i = 1; i < getNumberOfDimensions(valueShape); i++) {
        outputShape->dimensions.push_back(getSizeOfDimension(valueShape, i));
    }
    outputShape->offset = valueShape.offset;
    outputShape->scale = valueShape.scale;

    hitShape->type = OperandType::TENSOR_QUANT8_ASYMM;
    hitShape->dimensions = { lookups };
    hitShape->offset = 0;
    hitShape->scale = 1.f;

    return true;
}

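// PAD takes an [n, 2] paddings tensor holding the before/after padding for
// each of the n input dimensions. E.g. a {2, 3} input with paddings
// {{1, 1}, {0, 2}} yields a {4, 5} output.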
bool padPrepare(const Shape& input,
                const int32_t* paddingsData,
                const Shape& paddingsShape,
                Shape* output) {
    uint32_t numInputDims = getNumberOfDimensions(input);

    // paddings need to be provided as a 2-D int32 tensor.
    NN_OPS_CHECK(paddingsShape.type == OperandType::TENSOR_INT32);
    NN_OPS_CHECK(getNumberOfDimensions(paddingsShape) == 2);
    NN_OPS_CHECK(getSizeOfDimension(paddingsShape, 0) == numInputDims);
    NN_OPS_CHECK(getSizeOfDimension(paddingsShape, 1) == 2);

    std::vector<uint32_t> outDims(numInputDims);
    for (uint32_t i = 0; i < numInputDims; ++i) {
        int32_t beforePadding = *paddingsData++;
        int32_t afterPadding = *paddingsData++;
        // Pad value has to be greater than or equal to 0.
        NN_OPS_CHECK(beforePadding >= 0 && afterPadding >= 0);
        outDims[i] = beforePadding + getSizeOfDimension(input, i) + afterPadding;
    }
    output->type = input.type;
    output->dimensions = outDims;
    output->offset = input.offset;
    output->scale = input.scale;

    return true;
}

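// BATCH_TO_SPACE_ND is the inverse of SPACE_TO_BATCH_ND: blocks are moved
// from the batch dimension back into the spatial dimensions. E.g. a
// {4, 2, 2, 1} input with block size {2, 2} becomes {1, 4, 4, 1}.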
bool batchToSpacePrepare(const Shape& input,
                         const int32_t* blockSizeData,
                         const Shape& blockSizeShape,
                         Shape* output) {
    // Only 4D NHWC tensors are supported.
    NN_OPS_CHECK(getNumberOfDimensions(input) == 4);

    // blockSize needs to be provided as a 1-D int32 tensor.
    NN_OPS_CHECK(blockSizeShape.type == OperandType::TENSOR_INT32);
    NN_OPS_CHECK(getNumberOfDimensions(blockSizeShape) == 1);
    // Only applies to spatial dimensions.
    NN_OPS_CHECK(getSizeOfDimension(blockSizeShape, 0) == 2);

    uint32_t batches  = getSizeOfDimension(input, 0);
    uint32_t height   = getSizeOfDimension(input, 1);
    uint32_t width    = getSizeOfDimension(input, 2);
    uint32_t channels = getSizeOfDimension(input, 3);

    NN_OPS_CHECK(batches % (blockSizeData[0] * blockSizeData[1]) == 0);
    output->type = input.type;
    output->dimensions = {batches / (blockSizeData[0] * blockSizeData[1]),
                          height * blockSizeData[0],
                          width * blockSizeData[1],
                          channels};
    output->offset = input.offset;
    output->scale = input.scale;

    return true;
}

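// SPACE_TO_BATCH_ND first zero-pads the spatial dimensions (paddings is a
// {2, 2} tensor of before/after amounts), then moves each block of
// blockSize[0] x blockSize[1] spatial positions into the batch dimension.
// E.g. a {1, 4, 4, 1} input with block size {2, 2} and no padding becomes
// {4, 2, 2, 1}.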
bool spaceToBatchPrepare(const Shape& input,
                         const int32_t* blockSizeData,
                         const Shape& blockSizeShape,
                         const int32_t* paddingsData,
                         const Shape& paddingsShape,
                         Shape* output) {
    // Only 4D NHWC tensors are supported.
    NN_OPS_CHECK(getNumberOfDimensions(input) == 4);

    // blockSize needs to be provided as a 1-D int32 tensor.
    NN_OPS_CHECK(blockSizeShape.type == OperandType::TENSOR_INT32);
    NN_OPS_CHECK(getNumberOfDimensions(blockSizeShape) == 1);
    // Only applies to spatial dimensions.
    NN_OPS_CHECK(getSizeOfDimension(blockSizeShape, 0) == 2);

    // paddings need to be provided as a 2-D int32 tensor.
    NN_OPS_CHECK(paddingsShape.type == OperandType::TENSOR_INT32);
    NN_OPS_CHECK(getNumberOfDimensions(paddingsShape) == 2);
    NN_OPS_CHECK(getSizeOfDimension(paddingsShape, 0) == 2);
    NN_OPS_CHECK(getSizeOfDimension(paddingsShape, 1) == 2);

    uint32_t batches  = getSizeOfDimension(input, 0);
    uint32_t height   = getSizeOfDimension(input, 1);
    uint32_t width    = getSizeOfDimension(input, 2);
    uint32_t channels = getSizeOfDimension(input, 3);

    uint32_t paddedHeight = paddingsData[0] + height + paddingsData[1];
    uint32_t paddedWidth = paddingsData[2] + width + paddingsData[3];

    NN_OPS_CHECK(paddedHeight % blockSizeData[0] == 0);
    NN_OPS_CHECK(paddedWidth % blockSizeData[1] == 0);

    output->type = input.type;
    output->dimensions = {batches * (blockSizeData[0] * blockSizeData[1]),
                          paddedHeight / blockSizeData[0],
                          paddedWidth / blockSizeData[1],
                          channels};
    output->offset = input.offset;
    output->scale = input.scale;

    return true;
}

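// SQUEEZE removes size-1 dimensions. An empty squeezeDims removes every
// size-1 dimension; otherwise only the listed axes (negative values count
// from the end) are removed, and each must actually have size 1.
// E.g. squeezing {1, 3, 1, 4} with dims {} gives {3, 4}; with dims {0} it
// gives {3, 1, 4}.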
bool squeezePrepare(const Shape& input,
                    const int32_t* squeezeDims,
                    const Shape& squeezeDimsShape,
                    Shape* output) {
    int32_t numInputDims = static_cast<int32_t>(getNumberOfDimensions(input));

    // squeezeDims need to be provided as a 1-D int32 tensor.
    NN_OPS_CHECK(squeezeDimsShape.type == OperandType::TENSOR_INT32);
    NN_OPS_CHECK(getNumberOfDimensions(squeezeDimsShape) == 1);

    int32_t squeezeDimsSize = static_cast<int32_t>(getSizeOfDimension(squeezeDimsShape, 0));
    std::vector<bool> shouldSqueeze(numInputDims, false);
    int32_t numDimsSqueezed = 0;

    if (squeezeDimsSize == 0) {
        // If squeezeDimsSize is 0, all dims with value 1 will be squeezed.
        for (int32_t idx = 0; idx < numInputDims; ++idx) {
            if (getSizeOfDimension(input, idx) == 1) {
                shouldSqueeze[idx] = true;
                ++numDimsSqueezed;
            }
        }
    } else {
        for (int32_t idx = 0; idx < squeezeDimsSize; ++idx) {
            int32_t current = squeezeDims[idx] < 0 ? squeezeDims[idx] + numInputDims
                                                   : squeezeDims[idx];
            NN_OPS_CHECK(current >= 0 && current < numInputDims &&
                         getSizeOfDimension(input, current) == 1);
            if (!shouldSqueeze[current]) ++numDimsSqueezed;
            shouldSqueeze[current] = true;
        }
    }

    // Sets output dimensions.
    std::vector<uint32_t> outDims(numInputDims - numDimsSqueezed);
    for (int32_t inIdx = 0, outIdx = 0; inIdx < numInputDims; ++inIdx) {
        if (!shouldSqueeze[inIdx]) {
            outDims[outIdx++] = getSizeOfDimension(input, inIdx);
        }
    }

    output->type = input.type;
    output->dimensions = outDims;
    output->offset = input.offset;
    output->scale = input.scale;

    return true;
}

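// MEAN reduces over the axes listed in axisData (negative values count from
// the end; duplicates are tolerated). With keepDims the reduced axes stay as
// size 1, otherwise they are dropped. E.g. a {2, 3, 4} input reduced over
// axis {1} gives {2, 1, 4} with keepDims and {2, 4} without.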
bool meanPrepare(const Shape& input,
                 const int32_t* axisData,
                 const Shape& axisShape,
                 bool keepDims,
                 Shape* output) {
    // axis needs to be provided as a 1-D int32 tensor.
    NN_OPS_CHECK(axisShape.type == OperandType::TENSOR_INT32);
    NN_OPS_CHECK(getNumberOfDimensions(axisShape) == 1);

    int32_t numInputDims = static_cast<int32_t>(getNumberOfDimensions(input));
    int32_t axisSize = static_cast<int32_t>(getSizeOfDimension(axisShape, 0));

    // Determines size of output tensor.
    if (keepDims) {
        std::vector<uint32_t> outDims(numInputDims);
        for (int32_t idx = 0; idx < numInputDims; ++idx) {
            bool isAxis = false;
            for (int32_t axisIdx = 0; axisIdx < axisSize; ++axisIdx) {
                if (axisData[axisIdx] == idx || axisData[axisIdx] + numInputDims == idx) {
                    isAxis = true;
                    break;
                }
            }
            if (isAxis) {
                outDims[idx] = 1;
            } else {
                outDims[idx] = getSizeOfDimension(input, idx);
            }
        }
        output->dimensions = outDims;
    } else {
        // Calculates size of reducing axis.
        int32_t numReduceAxis = axisSize;
        for (int32_t i = 0; i < axisSize; ++i) {
            int32_t current = axisData[i];
            if (current < 0) {
                current += numInputDims;
            }
            NN_OPS_CHECK(current >= 0 && current < numInputDims);
            for (int32_t j = 0; j < i; ++j) {
                int32_t previous = axisData[j];
                if (previous < 0) {
                    previous += numInputDims;
                }
                if (current == previous) {
                    --numReduceAxis;
                    break;
                }
            }
        }
        // Determines output dimensions.
        std::vector<uint32_t> outDims(numInputDims - numReduceAxis);
        int32_t numSkipAxis = 0;
        for (int32_t idx = 0; idx < numInputDims; ++idx) {
            bool isAxis = false;
            for (int32_t axisIdx = 0; axisIdx < axisSize; ++axisIdx) {
                if (axisData[axisIdx] == idx || axisData[axisIdx] + numInputDims == idx) {
                    ++numSkipAxis;
                    isAxis = true;
                    break;
                }
            }
            if (!isAxis) {
                outDims[idx - numSkipAxis] = getSizeOfDimension(input, idx);
            }
        }
        output->dimensions = outDims;
    }

    output->type = input.type;
    output->offset = input.offset;
    output->scale = input.scale;

    return true;
}

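// STRIDED_SLICE shape inference. Bit i of beginMask/endMask selects the full
// range on axis i (begin/end are ignored there); bit i of shrinkAxisMask
// drops axis i from the output, which requires that axis to slice to size 1.
// Each output extent is ceil((end - begin) / stride), clamped at 0.
// E.g. a {4, 6} input with begin {1, 0}, end {3, 6}, strides {1, 2} and no
// masks produces {2, 3}.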
bool stridedSlicePrepare(const Shape& input,
                         const int32_t* beginData, const Shape& beginShape,
                         const int32_t* endData, const Shape& endShape,
                         const int32_t* stridesData, const Shape& stridesShape,
                         int32_t beginMask, int32_t endMask, int32_t shrinkAxisMask,
                         Shape* output) {
    uint32_t numInputDims = getNumberOfDimensions(input);
    // StridedSlice op only supports 1D-4D input arrays.
    NN_OPS_CHECK(numInputDims <= 4);

    NN_OPS_CHECK(getNumberOfDimensions(beginShape) == 1);
    NN_OPS_CHECK(getNumberOfDimensions(endShape) == 1);
    NN_OPS_CHECK(getNumberOfDimensions(stridesShape) == 1);

    NN_OPS_CHECK(getSizeOfDimension(beginShape, 0) == numInputDims);
    NN_OPS_CHECK(getSizeOfDimension(endShape, 0) == numInputDims);
    NN_OPS_CHECK(getSizeOfDimension(stridesShape, 0) == numInputDims);

    NN_OPS_CHECK(beginShape.type == OperandType::TENSOR_INT32);
    NN_OPS_CHECK(endShape.type == OperandType::TENSOR_INT32);
    NN_OPS_CHECK(stridesShape.type == OperandType::TENSOR_INT32);

    // Determine size of output tensor and map indices.
    std::vector<uint32_t> outDims;
    for (int32_t idx = 0; idx < static_cast<int32_t>(numInputDims); idx++) {
        int32_t dim = static_cast<int32_t>(getSizeOfDimension(input, idx));
        int32_t stride = stridesData[idx];
        // stride value has to be non-zero.
        NN_OPS_CHECK(stride != 0);
        bool positiveStride = stride > 0;

        int32_t begin = beginMask & (1 << idx)
                ? positiveStride ? 0 : dim - 1
                : ClampedIndex(beginData[idx], dim, positiveStride);
        int32_t end = endMask & (1 << idx)
                ? positiveStride ? dim : -1
                : ClampedIndex(endData[idx], dim, positiveStride);

        // This is valid for both positive and negative strides.
        int32_t outDim = std::ceil((end - begin) / static_cast<float>(stride));
        outDim = outDim < 0 ? 0 : static_cast<uint32_t>(outDim);
        if (!(shrinkAxisMask & (1 << idx))) {
            outDims.push_back(outDim);
        } else {
            if (outDim != 1) {
                LOG(ERROR) << "Output dim " << idx << " is " << outDim << ", expected 1";
                NN_OPS_CHECK(outDim == 1);
            }
        }
    }

    output->type = input.type;
    output->dimensions = outDims;
    output->offset = input.offset;
    output->scale = input.scale;

    return true;
}

bool argMinMaxPrepare(const Shape& input, int32_t axis, Shape* output) {
    NN_CHECK(handleNegativeAxis(input, &axis));

    output->type = OperandType::TENSOR_INT32;

    // Copy the input dimensions, omitting the axis dimension.
    output->dimensions.clear();
    output->dimensions.reserve(getNumberOfDimensions(input) - 1);
    output->dimensions.insert(output->dimensions.end(),
                              input.dimensions.begin(),
                              input.dimensions.begin() + axis);
    output->dimensions.insert(output->dimensions.end(),
                              input.dimensions.begin() + axis + 1,
                              input.dimensions.end());

    return true;
}

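// SPLIT divides the given axis (negative values are wrapped by
// handleNegativeAxis) into numOutputs equal slices; the axis size must divide
// evenly. E.g. splitting a {2, 6} input on axis 1 into 3 outputs yields three
// {2, 2} shapes. The output vector is assumed to be pre-sized by the caller.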
bool splitPrepare(const Shape& input, int32_t axis, int32_t numOutputs,
                  std::vector<Shape>* output) {
    NN_CHECK(handleNegativeAxis(input, &axis));

    const int32_t sizeOfAxisToSplit = input.dimensions[axis];
    NN_OPS_CHECK(sizeOfAxisToSplit % numOutputs == 0);
    const int32_t sliceSize = sizeOfAxisToSplit / numOutputs;

    for (int i = 0; i < numOutputs; ++i) {
        output->at(i).type = input.type;
        output->at(i).dimensions = input.dimensions;
        output->at(i).dimensions[axis] = sliceSize;
        output->at(i).offset = input.offset;
        output->at(i).scale = input.scale;
    }
    return true;
}

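// GROUPED_CONV_2D: input is [batch, h, w, depth_in] and filter is
// [depth_out, filter_h, filter_w, depth_in / numGroups], so filter dim 3
// times numGroups must equal the input depth, and depth_out must be
// divisible by numGroups.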
bool groupedConvPrepare(const Shape& input, const Shape& filter, const Shape& bias,
                        int32_t padding_left, int32_t padding_right, int32_t padding_top,
                        int32_t padding_bottom, int32_t stride_width, int32_t stride_height,
                        int32_t numGroups, Shape* output) {
    if (filter.type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
        NN_OPS_CHECK(input.type == OperandType::TENSOR_QUANT8_ASYMM);
    } else {
        NN_OPS_CHECK(input.type == filter.type);
    }
    if (input.type == OperandType::TENSOR_QUANT8_ASYMM) {
        NN_OPS_CHECK(bias.type == OperandType::TENSOR_INT32);
    } else {
        NN_OPS_CHECK(input.type == bias.type);
    }
    NN_OPS_CHECK(getNumberOfDimensions(input) == 4);
    NN_OPS_CHECK(getNumberOfDimensions(filter) == 4);
    NN_OPS_CHECK(getNumberOfDimensions(bias) == 1);

    NN_OPS_CHECK(getSizeOfDimension(filter, 0) == getSizeOfDimension(bias, 0));

    NN_OPS_CHECK(getSizeOfDimension(filter, 3) * numGroups == getSizeOfDimension(input, 3));
    NN_OPS_CHECK(getSizeOfDimension(filter, 0) % numGroups == 0);

    uint32_t channels_out = getSizeOfDimension(filter, 0);
    uint32_t width = getSizeOfDimension(input, 2);
    uint32_t height = getSizeOfDimension(input, 1);
    uint32_t filterWidth = getSizeOfDimension(filter, 2);
    uint32_t filterHeight = getSizeOfDimension(filter, 1);
    uint32_t batches = getSizeOfDimension(input, 0);

    NN_RET_CHECK_GT(static_cast<int32_t>(filterWidth), padding_left);
    NN_RET_CHECK_GT(static_cast<int32_t>(filterWidth), padding_right);
    NN_RET_CHECK_GT(static_cast<int32_t>(filterHeight), padding_top);
    NN_RET_CHECK_GT(static_cast<int32_t>(filterHeight), padding_bottom);

    uint32_t outWidth =
            computeOutSize(width, filterWidth, stride_width, padding_left, padding_right);
    uint32_t outHeight =
            computeOutSize(height, filterHeight, stride_height, padding_top, padding_bottom);

    output->type = input.type;
    output->dimensions = {batches, outHeight, outWidth, channels_out};
    return true;
}

} // namespace nn
} // namespace android