Home | History | Annotate | Download | only in operations
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #define LOG_TAG "Operations"
     18 
     19 #include "HalInterfaces.h"
     20 #include "IndexedShapeWrapper.h"
     21 #include "OperationResolver.h"
     22 #include "OperationsUtils.h"
     23 
     24 namespace android {
     25 namespace nn {
     26 namespace logical {
     27 
     28 constexpr uint32_t kNumInputs = 2;
     29 constexpr uint32_t kInputTensor1 = 0;
     30 constexpr uint32_t kInputTensor2 = 1;
     31 
     32 constexpr uint32_t kNumOutputs = 1;
     33 constexpr uint32_t kOutputTensor = 0;
     34 
     35 namespace {
     36 
     37 bool compute(const std::function<bool(bool, bool)>& func, const bool8* aData, const Shape& aShape,
     38              const bool8* bData, const Shape& bShape, bool8* outputData, const Shape& outputShape) {
     39     IndexedShapeWrapper aShapeIndexed(aShape);
     40     IndexedShapeWrapper bShapeIndexed(bShape);
     41     IndexedShapeWrapper outputShapeIndexed(outputShape);
     42     std::vector<uint32_t> curIndex(outputShape.dimensions.size(), 0);
     43     bool lastIndex = false;
     44     do {
     45         uint32_t outputFlatIndex;
     46         NN_RET_CHECK(outputShapeIndexed.indexToFlatIndex(curIndex, &outputFlatIndex));
     47         uint32_t aFlatIndex;
     48         NN_RET_CHECK(aShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &aFlatIndex));
     49         uint32_t bFlatIndex;
     50         NN_RET_CHECK(bShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &bFlatIndex));
     51 
     52         outputData[outputFlatIndex] = func(aData[aFlatIndex], bData[bFlatIndex]);
     53 
     54         NN_RET_CHECK(outputShapeIndexed.nextIndexInplace(&curIndex, &lastIndex));
     55     } while (!lastIndex);
     56     return true;
     57 }
     58 
     59 }  // namespace
     60 
     61 bool validate(const IOperationValidationContext* context) {
     62     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
     63     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
     64     OperandType inputType = context->getInputType(kInputTensor1);
     65     NN_RET_CHECK(inputType == OperandType::TENSOR_BOOL8)
     66             << "Unsupported tensor type for a logical operation";
     67     NN_RET_CHECK(validateInputTypes(context, {inputType, inputType}));
     68     NN_RET_CHECK(validateOutputTypes(context, {inputType}));
     69     return validateHalVersion(context, HalVersion::V1_2);
     70 }
     71 
     72 bool prepare(IOperationExecutionContext* context) {
     73     Shape input1 = context->getInputShape(kInputTensor1);
     74     Shape input2 = context->getInputShape(kInputTensor2);
     75     Shape output = context->getOutputShape(kOutputTensor);
     76     NN_RET_CHECK(calculateBroadcastedShape(input1, input2, &output));
     77     return context->setOutputShape(kOutputTensor, output);
     78 }
     79 
     80 bool executeAnd(IOperationExecutionContext* context) {
     81     return compute(
     82             std::logical_and<bool>(), context->getInputBuffer<bool8>(kInputTensor1),
     83             context->getInputShape(kInputTensor1), context->getInputBuffer<bool8>(kInputTensor2),
     84             context->getInputShape(kInputTensor2), context->getOutputBuffer<bool8>(kOutputTensor),
     85             context->getOutputShape(kOutputTensor));
     86 }
     87 
     88 bool executeOr(IOperationExecutionContext* context) {
     89     return compute(
     90             std::logical_or<bool>(), context->getInputBuffer<bool8>(kInputTensor1),
     91             context->getInputShape(kInputTensor1), context->getInputBuffer<bool8>(kInputTensor2),
     92             context->getInputShape(kInputTensor2), context->getOutputBuffer<bool8>(kOutputTensor),
     93             context->getOutputShape(kOutputTensor));
     94 }
     95 
     96 }  // namespace logical
     97 
     98 NN_REGISTER_OPERATION(LOGICAL_AND, "LOGICAL_AND", logical::validate, logical::prepare,
     99                       logical::executeAnd);
    100 NN_REGISTER_OPERATION(LOGICAL_OR, "LOGICAL_OR", logical::validate, logical::prepare,
    101                       logical::executeOr);
    102 
    103 }  // namespace nn
    104 }  // namespace android
    105