Home | History | Annotate | Download | only in operations
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #define LOG_TAG "Operations"
     18 
     19 #include "HalInterfaces.h"
     20 #include "OperationResolver.h"
     21 #include "OperationsUtils.h"
     22 #include "Tracing.h"
     23 
     24 #include <cmath>
     25 
     26 namespace android {
     27 namespace nn {
     28 namespace elementwise {
     29 
     30 constexpr uint32_t kNumInputs = 1;
     31 constexpr uint32_t kInputTensor = 0;
     32 
     33 constexpr uint32_t kNumOutputs = 1;
     34 constexpr uint32_t kOutputTensor = 0;
     35 
     36 namespace {
     37 
     38 template <typename T>
     39 inline bool compute(float func(float), const T* input, const Shape& shape, T* output) {
     40     const auto size = getNumberOfElements(shape);
     41     for (uint32_t i = 0; i < size; ++i) {
     42         output[i] = static_cast<T>(func(static_cast<float>(input[i])));
     43     }
     44     return true;
     45 }
     46 
     47 bool execute(IOperationExecutionContext* context, float func(float)) {
     48     switch (context->getInputType(kInputTensor)) {
     49         case OperandType::TENSOR_FLOAT16:
     50             return compute(func, context->getInputBuffer<_Float16>(kInputTensor),
     51                            context->getInputShape(kInputTensor),
     52                            context->getOutputBuffer<_Float16>(kOutputTensor));
     53         case OperandType::TENSOR_FLOAT32:
     54             return compute(func, context->getInputBuffer<float>(kInputTensor),
     55                            context->getInputShape(kInputTensor),
     56                            context->getOutputBuffer<float>(kOutputTensor));
     57         default:
     58             NN_RET_CHECK_FAIL() << "Unsupported tensor type for elementwise operation";
     59     }
     60 }
     61 
     62 }  // namespace
     63 
// Validates an elementwise operation before it is prepared or executed.
//
// Checks, in order:
//   - the operation has exactly kNumInputs inputs and kNumOutputs outputs;
//   - the input operand is TENSOR_FLOAT16 or TENSOR_FLOAT32;
//   - validateInputTypes / validateOutputTypes accept that same type for the
//     single input and the single output.
// The final result is whatever validateHalVersion reports for HalVersion::V1_2.
//
// NOTE(review): the NN_RET_CHECK* macros are defined outside this file;
// presumably they log and return false when the condition fails - confirm.
bool validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    OperandType inputType = context->getInputType(kInputTensor);
    NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
                 inputType == OperandType::TENSOR_FLOAT32)
            << "Unsupported tensor type for elementwise operation";
    // The output operand is required to have the same type as the input.
    NN_RET_CHECK(validateInputTypes(context, {inputType}));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    return validateHalVersion(context, HalVersion::V1_2);
}
     75 
// Derives the output shape from the input shape and stores it on the context.
//
// Reads the current input and output shapes, passes them to SetShape, and then
// hands the updated output shape to context->setOutputShape, whose result is
// returned.
//
// NOTE(review): SetShape is defined outside this file; presumably it copies
// the input's dimensions into `output` - confirm against OperationsUtils.
bool prepare(IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    Shape output = context->getOutputShape(kOutputTensor);
    NN_RET_CHECK(SetShape(input, &output));
    return context->setOutputShape(kOutputTensor, output);
}
     82 
     83 bool executeAbs(IOperationExecutionContext* context) {
     84     return execute(context, std::abs);
     85 }
     86 
     87 bool executeExp(IOperationExecutionContext* context) {
     88     return execute(context, std::exp);
     89 }
     90 
     91 bool executeLog(IOperationExecutionContext* context) {
     92     return execute(context, std::log);
     93 }
     94 
     95 bool executeRsqrt(IOperationExecutionContext* context) {
     96     return execute(context, [](float x) { return 1.f / std::sqrt(x); });
     97 }
     98 
     99 bool executeSin(IOperationExecutionContext* context) {
    100     return execute(context, std::sin);
    101 }
    102 
    103 bool executeSqrt(IOperationExecutionContext* context) {
    104     return execute(context, std::sqrt);
    105 }
    106 
    107 }  // namespace elementwise
    108 
// Registrations for the six elementwise operations implemented above.
// Every entry shares elementwise::validate and elementwise::prepare; only the
// execute callback differs per operation.
// NOTE(review): NN_REGISTER_OPERATION is defined outside this file; presumably
// it ties the operation identifier and name to these callbacks for the
// operation resolver - confirm against OperationResolver.h.
NN_REGISTER_OPERATION(ABS, "ABS", elementwise::validate, elementwise::prepare,
                      elementwise::executeAbs);
NN_REGISTER_OPERATION(EXP, "EXP", elementwise::validate, elementwise::prepare,
                      elementwise::executeExp);
NN_REGISTER_OPERATION(LOG, "LOG", elementwise::validate, elementwise::prepare,
                      elementwise::executeLog);
NN_REGISTER_OPERATION(RSQRT, "RSQRT", elementwise::validate, elementwise::prepare,
                      elementwise::executeRsqrt);
NN_REGISTER_OPERATION(SIN, "SIN", elementwise::validate, elementwise::prepare,
                      elementwise::executeSin);
NN_REGISTER_OPERATION(SQRT, "SQRT", elementwise::validate, elementwise::prepare,
                      elementwise::executeSqrt);
    121 
    122 }  // namespace nn
    123 }  // namespace android
    124