Home | History | Annotate | Download | only in include
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef ANDROID_ML_NN_COMMON_OPERATION_RESOLVER_H
     18 #define ANDROID_ML_NN_COMMON_OPERATION_RESOLVER_H
     19 
#include <functional>
#include <utility>

#include "HalInterfaces.h"
#include "OperationsUtils.h"
     22 
     23 namespace android {
     24 namespace nn {
     25 
     26 // Encapsulates an operation implementation.
     27 struct OperationRegistration {
     28     OperationType type;
     29     const char* name;
     30 
     31     // Validates operand types, shapes, and any values known during graph creation.
     32     std::function<bool(const IOperationValidationContext*)> validate;
     33 
     34     // prepare is called when the inputs this operation depends on have been
     35     // computed. Typically, prepare does any remaining validation and sets
     36     // output shapes via context->setOutputShape(...).
     37     std::function<bool(IOperationExecutionContext*)> prepare;
     38 
     39     // Executes the operation, reading from context->getInputBuffer(...)
     40     // and writing to context->getOutputBuffer(...).
     41     std::function<bool(IOperationExecutionContext*)> execute;
     42 
     43     struct Flag {
     44         // Whether the operation allows at least one operand to be omitted.
     45         bool allowOmittedOperand = false;
     46         // Whether the operation allows at least one input operand to be a zero-sized tensor.
     47         bool allowZeroSizedInput = false;
     48     } flags;
     49 
     50     OperationRegistration(OperationType type, const char* name,
     51                           std::function<bool(const IOperationValidationContext*)> validate,
     52                           std::function<bool(IOperationExecutionContext*)> prepare,
     53                           std::function<bool(IOperationExecutionContext*)> execute, Flag flags)
     54         : type(type),
     55           name(name),
     56           validate(validate),
     57           prepare(prepare),
     58           execute(execute),
     59           flags(flags) {}
     60 };
     61 
     62 // A registry of operation implementations.
     63 class IOperationResolver {
     64    public:
     65     virtual const OperationRegistration* findOperation(OperationType operationType) const = 0;
     66     virtual ~IOperationResolver() {}
     67 };
     68 
     69 // A registry of builtin operation implementations.
     70 //
     71 // Note that some operations bypass BuiltinOperationResolver (b/124041202).
     72 //
     73 // Usage:
     74 //   const OperationRegistration* operationRegistration =
     75 //           BuiltinOperationResolver::get()->findOperation(operationType);
     76 //   NN_RET_CHECK(operationRegistration != nullptr);
     77 //   NN_RET_CHECK(operationRegistration->validate != nullptr);
     78 //   NN_RET_CHECK(operationRegistration->validate(&context));
     79 //
     80 class BuiltinOperationResolver : public IOperationResolver {
     81     DISALLOW_COPY_AND_ASSIGN(BuiltinOperationResolver);
     82 
     83    public:
     84     static const BuiltinOperationResolver* get() {
     85         static BuiltinOperationResolver instance;
     86         return &instance;
     87     }
     88 
     89     const OperationRegistration* findOperation(OperationType operationType) const override;
     90 
     91    private:
     92     BuiltinOperationResolver();
     93 
     94     void registerOperation(const OperationRegistration* operationRegistration);
     95 
     96     const OperationRegistration* mRegistrations[kNumberOfOperationTypes] = {};
     97 };
     98 
     99 // NN_REGISTER_OPERATION creates OperationRegistration for consumption by
    100 // OperationResolver.
    101 //
    102 // Usage:
    103 // (check OperationRegistration::Flag for available fields and default values.)
    104 //
    105 // - With default flags.
    106 //   NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
    107 //                         foo_op::prepare, foo_op::execute);
    108 //
    109 // - With a customized flag.
    110 //   NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
    111 //                         foo_op::prepare, foo_op::execute, .allowZeroSizedInput = true);
    112 //
    113 // - With multiple customized flags.
    114 //   NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
    115 //                         foo_op::prepare, foo_op::execute, .allowOmittedOperand = true,
    116 //                         .allowZeroSizedInput = true);
    117 //
    118 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
    119 #define NN_REGISTER_OPERATION(identifier, operationName, validate, prepare, execute, ...)     \
    120     const OperationRegistration* register_##identifier() {                                    \
    121         static OperationRegistration registration(OperationType::identifier, operationName,   \
    122                                                   validate, prepare, execute, {__VA_ARGS__}); \
    123         return &registration;                                                                 \
    124     }
    125 #else
    126 // This version ignores CPU execution logic (prepare and execute).
    127 // The compiler is supposed to omit that code so that only validation logic
    128 // makes it into libneuralnetworks_utils.
    129 #define NN_REGISTER_OPERATION(identifier, operationName, validate, unused_prepare, unused_execute, \
    130                               ...)                                                                 \
    131     const OperationRegistration* register_##identifier() {                                         \
    132         static OperationRegistration registration(OperationType::identifier, operationName,        \
    133                                                   validate, nullptr, nullptr, {__VA_ARGS__});      \
    134         return &registration;                                                                      \
    135     }
    136 #endif
    137 
    138 }  // namespace nn
    139 }  // namespace android
    140 
    141 #endif  // ANDROID_ML_NN_COMMON_OPERATION_RESOLVER_H
    142