Home | History | Annotate | Download | only in common
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #define LOG_TAG "OperationResolver"
     18 
     19 #include "OperationResolver.h"
     20 
     21 #include "NeuralNetworks.h"
     22 
     23 namespace android {
     24 namespace nn {
     25 
     26 // TODO(b/119608412): Find a way to not reference every operation here.
     27 const OperationRegistration* register_ABS();
     28 const OperationRegistration* register_ADD();
     29 const OperationRegistration* register_AVERAGE_POOL_2D();
     30 const OperationRegistration* register_AXIS_ALIGNED_BBOX_TRANSFORM();
     31 const OperationRegistration* register_BIDIRECTIONAL_SEQUENCE_RNN();
     32 const OperationRegistration* register_BOX_WITH_NMS_LIMIT();
     33 const OperationRegistration* register_CHANNEL_SHUFFLE();
     34 const OperationRegistration* register_CONCATENATION();
     35 const OperationRegistration* register_CONV_2D();
     36 const OperationRegistration* register_DEQUANTIZE();
     37 const OperationRegistration* register_DETECTION_POSTPROCESSING();
     38 const OperationRegistration* register_DIV();
     39 const OperationRegistration* register_EQUAL();
     40 const OperationRegistration* register_EXP();
     41 const OperationRegistration* register_FULLY_CONNECTED();
     42 const OperationRegistration* register_GATHER();
     43 const OperationRegistration* register_GENERATE_PROPOSALS();
     44 const OperationRegistration* register_GREATER();
     45 const OperationRegistration* register_GREATER_EQUAL();
     46 const OperationRegistration* register_HEATMAP_MAX_KEYPOINT();
     47 const OperationRegistration* register_INSTANCE_NORMALIZATION();
     48 const OperationRegistration* register_L2_NORMALIZATION();
     49 const OperationRegistration* register_L2_POOL_2D();
     50 const OperationRegistration* register_LESS();
     51 const OperationRegistration* register_LESS_EQUAL();
     52 const OperationRegistration* register_LOG();
     53 const OperationRegistration* register_LOGICAL_AND();
     54 const OperationRegistration* register_LOGICAL_NOT();
     55 const OperationRegistration* register_LOGICAL_OR();
     56 const OperationRegistration* register_LOGISTIC();
     57 const OperationRegistration* register_LOG_SOFTMAX();
     58 const OperationRegistration* register_MAX_POOL_2D();
     59 const OperationRegistration* register_MUL();
     60 const OperationRegistration* register_NEG();
     61 const OperationRegistration* register_NOT_EQUAL();
     62 const OperationRegistration* register_PRELU();
     63 const OperationRegistration* register_QUANTIZE();
     64 const OperationRegistration* register_REDUCE_ALL();
     65 const OperationRegistration* register_REDUCE_ANY();
     66 const OperationRegistration* register_REDUCE_MAX();
     67 const OperationRegistration* register_REDUCE_MIN();
     68 const OperationRegistration* register_REDUCE_PROD();
     69 const OperationRegistration* register_REDUCE_SUM();
     70 const OperationRegistration* register_RELU();
     71 const OperationRegistration* register_RELU1();
     72 const OperationRegistration* register_RELU6();
     73 const OperationRegistration* register_RESIZE_BILINEAR();
     74 const OperationRegistration* register_RESIZE_NEAREST_NEIGHBOR();
     75 const OperationRegistration* register_ROI_ALIGN();
     76 const OperationRegistration* register_ROI_POOLING();
     77 const OperationRegistration* register_RSQRT();
     78 const OperationRegistration* register_SELECT();
     79 const OperationRegistration* register_SIN();
     80 const OperationRegistration* register_SLICE();
     81 const OperationRegistration* register_SOFTMAX();
     82 const OperationRegistration* register_SQRT();
     83 const OperationRegistration* register_SUB();
     84 const OperationRegistration* register_TANH();
     85 const OperationRegistration* register_TRANSPOSE();
     86 const OperationRegistration* register_TRANSPOSE_CONV_2D();
     87 const OperationRegistration* register_UNIDIRECTIONAL_SEQUENCE_LSTM();
     88 const OperationRegistration* register_UNIDIRECTIONAL_SEQUENCE_RNN();
     89 
     90 BuiltinOperationResolver::BuiltinOperationResolver() {
     91     registerOperation(register_ABS());
     92     registerOperation(register_ADD());
     93     registerOperation(register_AVERAGE_POOL_2D());
     94     registerOperation(register_AXIS_ALIGNED_BBOX_TRANSFORM());
     95     registerOperation(register_BIDIRECTIONAL_SEQUENCE_RNN());
     96     registerOperation(register_BOX_WITH_NMS_LIMIT());
     97     registerOperation(register_CHANNEL_SHUFFLE());
     98     registerOperation(register_CONCATENATION());
     99     registerOperation(register_CONV_2D());
    100     registerOperation(register_DEQUANTIZE());
    101     registerOperation(register_DETECTION_POSTPROCESSING());
    102     registerOperation(register_DIV());
    103     registerOperation(register_EQUAL());
    104     registerOperation(register_EXP());
    105     registerOperation(register_FULLY_CONNECTED());
    106     registerOperation(register_GATHER());
    107     registerOperation(register_GENERATE_PROPOSALS());
    108     registerOperation(register_GREATER());
    109     registerOperation(register_GREATER_EQUAL());
    110     registerOperation(register_HEATMAP_MAX_KEYPOINT());
    111     registerOperation(register_INSTANCE_NORMALIZATION());
    112     registerOperation(register_L2_NORMALIZATION());
    113     registerOperation(register_L2_POOL_2D());
    114     registerOperation(register_LESS());
    115     registerOperation(register_LESS_EQUAL());
    116     registerOperation(register_LOG());
    117     registerOperation(register_LOGICAL_AND());
    118     registerOperation(register_LOGICAL_NOT());
    119     registerOperation(register_LOGICAL_OR());
    120     registerOperation(register_LOGISTIC());
    121     registerOperation(register_LOG_SOFTMAX());
    122     registerOperation(register_MAX_POOL_2D());
    123     registerOperation(register_MUL());
    124     registerOperation(register_NEG());
    125     registerOperation(register_NOT_EQUAL());
    126     registerOperation(register_PRELU());
    127     registerOperation(register_QUANTIZE());
    128     registerOperation(register_REDUCE_ALL());
    129     registerOperation(register_REDUCE_ANY());
    130     registerOperation(register_REDUCE_MAX());
    131     registerOperation(register_REDUCE_MIN());
    132     registerOperation(register_REDUCE_PROD());
    133     registerOperation(register_REDUCE_SUM());
    134     registerOperation(register_RELU());
    135     registerOperation(register_RELU1());
    136     registerOperation(register_RELU6());
    137     registerOperation(register_RESIZE_BILINEAR());
    138     registerOperation(register_RESIZE_NEAREST_NEIGHBOR());
    139     registerOperation(register_ROI_ALIGN());
    140     registerOperation(register_ROI_POOLING());
    141     registerOperation(register_RSQRT());
    142     registerOperation(register_SELECT());
    143     registerOperation(register_SIN());
    144     registerOperation(register_SLICE());
    145     registerOperation(register_SOFTMAX());
    146     registerOperation(register_SQRT());
    147     registerOperation(register_SUB());
    148     registerOperation(register_TANH());
    149     registerOperation(register_TRANSPOSE());
    150     registerOperation(register_TRANSPOSE_CONV_2D());
    151     registerOperation(register_UNIDIRECTIONAL_SEQUENCE_LSTM());
    152     registerOperation(register_UNIDIRECTIONAL_SEQUENCE_RNN());
    153 }
    154 
    155 const OperationRegistration* BuiltinOperationResolver::findOperation(
    156         OperationType operationType) const {
    157     auto index = static_cast<int32_t>(operationType);
    158     if (index < 0 || index >= kNumberOfOperationTypes) {
    159         return nullptr;
    160     }
    161     return mRegistrations[index];
    162 }
    163 
    164 void BuiltinOperationResolver::registerOperation(
    165         const OperationRegistration* operationRegistration) {
    166     CHECK(operationRegistration != nullptr);
    167     auto index = static_cast<int32_t>(operationRegistration->type);
    168     CHECK_LE(0, index);
    169     CHECK_LT(index, kNumberOfOperationTypes);
    170     CHECK(mRegistrations[index] == nullptr);
    171     mRegistrations[index] = operationRegistration;
    172 }
    173 
    174 }  // namespace nn
    175 }  // namespace android
    176