1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_ML_NN_COMMON_OPERATIONS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ 18 #define ANDROID_ML_NN_COMMON_OPERATIONS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ 19 20 #include "fixedpoint.h" 21 #include "gemmlowp.h" 22 #include "../common.h" 23 #include "../types.h" 24 25 namespace android { 26 namespace nn { 27 namespace reference_ops { 28 29 template <FusedActivationFunctionType Ac> 30 void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, 31 int32 input_offset, const uint8* filter_data, 32 const Dims<4>& filter_dims, int32 filter_offset, 33 const int32* bias_data, const Dims<4>& bias_dims, 34 int stride_width, int stride_height, 35 int pad_width, int pad_height, int depth_multiplier, 36 int32 output_offset, int32 output_multiplier, 37 int output_shift, int32 output_activation_min, 38 int32 output_activation_max, uint8* output_data, 39 const Dims<4>& output_dims) { 40 static_assert(Ac == FusedActivationFunctionType::kNone || 41 Ac == FusedActivationFunctionType::kRelu || 42 Ac == FusedActivationFunctionType::kRelu6 || 43 Ac == FusedActivationFunctionType::kRelu1, 44 ""); 45 DCHECK_LE(output_activation_min, output_activation_max); 46 if (Ac == FusedActivationFunctionType::kNone) { 47 DCHECK_EQ(output_activation_min, 0); 48 DCHECK_EQ(output_activation_max, 255); 49 } 50 const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); 51 const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); 52 const int input_height = ArraySize(input_dims, 2); 53 const int input_width = ArraySize(input_dims, 1); 54 const int input_depth = ArraySize(input_dims, 0); 55 const int filter_height = ArraySize(filter_dims, 2); 56 const int filter_width = ArraySize(filter_dims, 1); 57 const int output_height = ArraySize(output_dims, 2); 58 const int output_width = ArraySize(output_dims, 1); 59 DCHECK(output_depth == input_depth * depth_multiplier); 60 61 for (int b = 0; b < batches; ++b) { 62 for (int out_y = 0; out_y < output_height; ++out_y) { 63 for (int out_x = 0; out_x < output_width; ++out_x) { 64 for (int ic = 0; ic < input_depth; ++ic) { 65 for (int m = 0; m < depth_multiplier; m++) { 66 const int oc = m + ic * depth_multiplier; 67 const int in_x_origin = (out_x * stride_width) - pad_width; 68 const int in_y_origin = (out_y * stride_height) - pad_height; 69 int32 acc = 0; 70 for (int filter_y = 0; filter_y < filter_height; ++filter_y) { 71 for (int filter_x = 0; filter_x < filter_width; ++filter_x) { 72 const int in_x = in_x_origin + filter_x; 73 const int in_y = in_y_origin + filter_y; 74 // If the location is outside the bounds of the input image, 75 // use zero as a default value. 76 if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && 77 (in_y < input_height)) { 78 int32 input_val = 79 input_data[Offset(input_dims, ic, in_x, in_y, b)]; 80 int32 filter_val = filter_data[Offset(filter_dims, oc, 81 filter_x, filter_y, 0)]; 82 acc += 83 (filter_val + filter_offset) * (input_val + input_offset); 84 } 85 } 86 } 87 if (bias_data) { 88 acc += bias_data[Offset(bias_dims, oc, 0, 0, 0)]; 89 } 90 acc = MultiplyByQuantizedMultiplierSmallerThanOne( 91 acc, output_multiplier, output_shift); 92 acc += output_offset; 93 acc = std::max(acc, output_activation_min); 94 acc = std::min(acc, output_activation_max); 95 output_data[Offset(output_dims, oc, out_x, out_y, b)] = 96 static_cast<uint8>(acc); 97 } 98 } 99 } 100 } 101 } 102 } 103 104 } // end namespace reference_ops 105 } // namespace nn 106 } // namespace android 107 108 #endif // ANDROID_ML_NN_COMMON_OPERATIONS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ 109