Home | History | Annotate | Download | only in reference
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef ANDROID_ML_NN_COMMON_OPERATIONS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_
     18 #define ANDROID_ML_NN_COMMON_OPERATIONS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_
     19 
     20 #include "fixedpoint.h"
     21 #include "gemmlowp.h"
     22 #include "../common.h"
     23 #include "../types.h"
     24 
     25 namespace android {
     26 namespace nn {
     27 namespace reference_ops {
     28 
     29 template <FusedActivationFunctionType Ac>
     30 void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
     31                    int32 input_offset, const uint8* filter_data,
     32                    const Dims<4>& filter_dims, int32 filter_offset,
     33                    const int32* bias_data, const Dims<4>& bias_dims,
     34                    int stride_width, int stride_height,
     35                    int pad_width, int pad_height, int depth_multiplier,
     36                    int32 output_offset, int32 output_multiplier,
     37                    int output_shift, int32 output_activation_min,
     38                    int32 output_activation_max, uint8* output_data,
     39                    const Dims<4>& output_dims) {
     40   static_assert(Ac == FusedActivationFunctionType::kNone ||
     41                     Ac == FusedActivationFunctionType::kRelu ||
     42                     Ac == FusedActivationFunctionType::kRelu6 ||
     43                     Ac == FusedActivationFunctionType::kRelu1,
     44                 "");
     45   DCHECK_LE(output_activation_min, output_activation_max);
     46   if (Ac == FusedActivationFunctionType::kNone) {
     47     DCHECK_EQ(output_activation_min, 0);
     48     DCHECK_EQ(output_activation_max, 255);
     49   }
     50   const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
     51   const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
     52   const int input_height = ArraySize(input_dims, 2);
     53   const int input_width = ArraySize(input_dims, 1);
     54   const int input_depth = ArraySize(input_dims, 0);
     55   const int filter_height = ArraySize(filter_dims, 2);
     56   const int filter_width = ArraySize(filter_dims, 1);
     57   const int output_height = ArraySize(output_dims, 2);
     58   const int output_width = ArraySize(output_dims, 1);
     59   DCHECK(output_depth == input_depth * depth_multiplier);
     60 
     61   for (int b = 0; b < batches; ++b) {
     62     for (int out_y = 0; out_y < output_height; ++out_y) {
     63       for (int out_x = 0; out_x < output_width; ++out_x) {
     64         for (int ic = 0; ic < input_depth; ++ic) {
     65           for (int m = 0; m < depth_multiplier; m++) {
     66             const int oc = m + ic * depth_multiplier;
     67             const int in_x_origin = (out_x * stride_width) - pad_width;
     68             const int in_y_origin = (out_y * stride_height) - pad_height;
     69             int32 acc = 0;
     70             for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
     71               for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
     72                 const int in_x = in_x_origin + filter_x;
     73                 const int in_y = in_y_origin + filter_y;
     74                 // If the location is outside the bounds of the input image,
     75                 // use zero as a default value.
     76                 if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
     77                     (in_y < input_height)) {
     78                   int32 input_val =
     79                       input_data[Offset(input_dims, ic, in_x, in_y, b)];
     80                   int32 filter_val = filter_data[Offset(filter_dims, oc,
     81                                                         filter_x, filter_y, 0)];
     82                   acc +=
     83                       (filter_val + filter_offset) * (input_val + input_offset);
     84                 }
     85               }
     86             }
     87             if (bias_data) {
     88               acc += bias_data[Offset(bias_dims, oc, 0, 0, 0)];
     89             }
     90             acc = MultiplyByQuantizedMultiplierSmallerThanOne(
     91                 acc, output_multiplier, output_shift);
     92             acc += output_offset;
     93             acc = std::max(acc, output_activation_min);
     94             acc = std::min(acc, output_activation_max);
     95             output_data[Offset(output_dims, oc, out_x, out_y, b)] =
     96                 static_cast<uint8>(acc);
     97           }
     98         }
     99       }
    100     }
    101   }
    102 }
    103 
    104 }  // end namespace reference_ops
    105 }  // namespace nn
    106 }  // namespace android
    107 
    108 #endif  // ANDROID_ML_NN_COMMON_OPERATIONS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_
    109