/*
 * Copyright (c) 2018, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <stdbool.h>
#include <assert.h>
#include <pmmintrin.h>

#include "config/av1_rtcd.h"
#include "av1/encoder/ml.h"

// In order to avoid the high latency of swapping between FPU and SIMD
// operations, we keep the result in a 128-bit register even though we only
// care about a single value.
static void nn_propagate_8to1(const float *const inputs,
                              const float *const weights,
                              __m128 *const output) {
  const __m128 inputs_h = _mm_loadu_ps(&inputs[4]);
  const __m128 inputs_l = _mm_loadu_ps(inputs);

  const __m128 weights_h = _mm_loadu_ps(&weights[4]);
  const __m128 weights_l = _mm_loadu_ps(weights);

  const __m128 mul_h = _mm_mul_ps(inputs_h, weights_h);
  const __m128 mul_l = _mm_mul_ps(inputs_l, weights_l);
  // [7 6 5 4] [3 2 1 0] (weight and input indices)

  const __m128 vadd = _mm_add_ps(mul_l, mul_h);
  // [7+3 6+2 5+1 4+0]
  const __m128 hadd1 = _mm_hadd_ps(vadd, vadd);
  // [7+6+3+2 5+4+1+0 7+6+3+2 5+4+1+0]
  const __m128 hadd2 = _mm_hadd_ps(hadd1, hadd1);
  // [7+6+5+4+3+2+1+0 7+6+5+4+3+2+1+0 7+6+5+4+3+2+1+0 7+6+5+4+3+2+1+0]
  *output = _mm_add_ps(*output, hadd2);
}
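
// Worked example of the reduction above (illustrative values, not from the
// library): with inputs[i] = i + 1, i.e. {1, 2, 3, 4, 5, 6, 7, 8}, and all
// weights equal to 1.0f, the lanes evolve as
//   vadd  = [8+4 7+3 6+2 5+1]   = [12 10  8  6]
//   hadd1 = [12+10 8+6 12+10 8+6] = [22 14 22 14]
//   hadd2 = [22+14 ...]           = [36 36 36 36]
// Every lane of hadd2 holds the full dot product; only the lowest lane is
// ultimately read back with _mm_cvtss_f32() in av1_nn_predict_sse3().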

static void nn_propagate_4to1(const float *const inputs,
                              const float *const weights,
                              __m128 *const output) {
  const __m128 inputs128 = _mm_loadu_ps(inputs);

  const __m128 weights128 = _mm_loadu_ps(weights);

  const __m128 mul = _mm_mul_ps(inputs128, weights128);
  // [3 2 1 0] (weight and input indices)

  const __m128 hadd1 = _mm_hadd_ps(mul, mul);
  // [3+2 1+0 3+2 1+0]
  const __m128 hadd2 = _mm_hadd_ps(hadd1, hadd1);
  // [3+2+1+0 3+2+1+0 3+2+1+0 3+2+1+0]
  *output = _mm_add_ps(*output, hadd2);
}

static void nn_propagate_4to4(const float *const inputs,
                              const float *const weights, __m128 *const outputs,
                              const int num_inputs) {
  const __m128 inputs128 = _mm_loadu_ps(inputs);

  __m128 hadd[2];
  for (int i = 0; i < 2; i++) {  // For each pair of outputs
    const __m128 weight0 = _mm_loadu_ps(&weights[2 * i * num_inputs]);
    const __m128 mul0 = _mm_mul_ps(weight0, inputs128);
    const __m128 weight1 = _mm_loadu_ps(&weights[(2 * i + 1) * num_inputs]);
    const __m128 mul1 = _mm_mul_ps(weight1, inputs128);
    hadd[i] = _mm_hadd_ps(mul0, mul1);
  }
  // hadd[0] = [7+6 5+4 3+2 1+0] (weight indices)
  // hadd[1] = [15+14 13+12 11+10 9+8]

  const __m128 hh = _mm_hadd_ps(hadd[0], hadd[1]);
  // [15+14+13+12 11+10+9+8 7+6+5+4 3+2+1+0]

  *outputs = _mm_add_ps(*outputs, hh);
}

static void nn_propagate_4to8(const float *const inputs,
                              const float *const weights, __m128 *const out_h,
                              __m128 *const out_l, const int num_inputs) {
  const __m128 inputs128 = _mm_loadu_ps(inputs);

  __m128 hadd[4];
  for (int i = 0; i < 4; i++) {  // For each pair of outputs
    const __m128 weight0 = _mm_loadu_ps(&weights[2 * i * num_inputs]);
    const __m128 weight1 = _mm_loadu_ps(&weights[(2 * i + 1) * num_inputs]);
    const __m128 mul0 = _mm_mul_ps(inputs128, weight0);
    const __m128 mul1 = _mm_mul_ps(inputs128, weight1);
    hadd[i] = _mm_hadd_ps(mul0, mul1);
  }
  // hadd[0] = [7+6 5+4 3+2 1+0] (weight indices)
  // hadd[1] = [15+14 13+12 11+10 9+8]
  // hadd[2] = [23+22 21+20 19+18 17+16]
  // hadd[3] = [31+30 29+28 27+26 25+24]

  const __m128 hh0 = _mm_hadd_ps(hadd[0], hadd[1]);
  // [15+14+13+12 11+10+9+8 7+6+5+4 3+2+1+0]
  const __m128 hh1 = _mm_hadd_ps(hadd[2], hadd[3]);
  // [31+30+29+28 27+26+25+24 23+22+21+20 19+18+17+16]

  *out_h = _mm_add_ps(*out_h, hh1);
  *out_l = _mm_add_ps(*out_l, hh0);
}

static void nn_propagate_8to4(const float *const inputs,
                              const float *const weights, __m128 *const outputs,
                              const int num_inputs) {
  const __m128 inputs_h = _mm_loadu_ps(inputs + 4);
  const __m128 inputs_l = _mm_loadu_ps(inputs);
  // [7 6 5 4] [3 2 1 0] (input indices)

  __m128 add[4];
  for (int i = 0; i < 4; i++) {  // For each output:
    const __m128 weight_h = _mm_loadu_ps(&weights[i * num_inputs + 4]);
    const __m128 weight_l = _mm_loadu_ps(&weights[i * num_inputs]);
    const __m128 mul_h = _mm_mul_ps(inputs_h, weight_h);
    const __m128 mul_l = _mm_mul_ps(inputs_l, weight_l);
    add[i] = _mm_add_ps(mul_l, mul_h);
  }
  // add[0] = [7+3 6+2 5+1 4+0]
  // add[1] = [15+11 14+10 13+9 12+8]
  // add[2] = [23+19 22+18 21+17 20+16]
  // add[3] = [31+27 30+26 29+25 28+24]

  const __m128 hadd_h = _mm_hadd_ps(add[2], add[3]);
  // [31+30+27+26 29+28+25+24 23+22+19+18 21+20+17+16]
  const __m128 hadd_l = _mm_hadd_ps(add[0], add[1]);
  // [15+14+11+10 13+12+9+8 7+6+3+2 5+4+1+0]

  const __m128 haddhadd = _mm_hadd_ps(hadd_l, hadd_h);
  // [31+30+29+28+27+26+25+24 23+22+21+20+19+18+17+16
  //  15+14+13+12+11+10+9+8 7+6+5+4+3+2+1+0]

  *outputs = _mm_add_ps(*outputs, haddhadd);
}

static void nn_activate8(__m128 *out_h, __m128 *out_l) {
  const __m128 zero = _mm_setzero_ps();
  *out_h = _mm_max_ps(*out_h, zero);
  *out_l = _mm_max_ps(*out_l, zero);
}

static void nn_activate4(__m128 *x) { *x = _mm_max_ps(*x, _mm_setzero_ps()); }
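
// For reference, each specialized SIMD path in av1_nn_predict_sse3() below
// computes, per output node, the same bias-plus-dot-product as this scalar
// sketch (illustrative only; the generic C fallback is av1_nn_predict_c()):
//
//   for (int out = 0; out < num_outputs; out++) {
//     float total = layer_bias[out];
//     for (int in = 0; in < num_inputs; in++)
//       total += input_nodes[in] * layer_weights[out * num_inputs + in];
//     // ReLU on hidden layers only; the output layer is left linear.
//     output_nodes[out] = output_layer ? total : fmaxf(total, 0.0f);
//   }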

// Calculate prediction based on the given input features and neural net config.
// Assume there are no more than NN_MAX_NODES_PER_LAYER nodes in each hidden
// layer.
void av1_nn_predict_sse3(const float *input_nodes,
                         const NN_CONFIG *const nn_config,
                         float *const output) {
  float buf[2][NN_MAX_NODES_PER_LAYER];
  int buf_index = 0;
  int num_inputs = nn_config->num_inputs;

  // Hidden layers, except the final iteration is the output layer.
  for (int layer = 0; layer <= nn_config->num_hidden_layers; layer++) {
    const float *layer_weights = nn_config->weights[layer];
    const float *layer_bias = nn_config->bias[layer];
    bool output_layer = (layer == nn_config->num_hidden_layers);
    float *const output_nodes = output_layer ? output : buf[buf_index];
    const int num_outputs = output_layer ? nn_config->num_outputs
                                         : nn_config->num_hidden_nodes[layer];

    if (num_inputs % 4 == 0 && num_outputs % 8 == 0) {
      for (int out = 0; out < num_outputs; out += 8) {
        __m128 out_h = _mm_loadu_ps(&layer_bias[out + 4]);
        __m128 out_l = _mm_loadu_ps(&layer_bias[out]);
        for (int in = 0; in < num_inputs; in += 4) {
          nn_propagate_4to8(&input_nodes[in],
                            &layer_weights[out * num_inputs + in], &out_h,
                            &out_l, num_inputs);
        }
        if (!output_layer) nn_activate8(&out_h, &out_l);
        _mm_storeu_ps(&output_nodes[out + 4], out_h);
        _mm_storeu_ps(&output_nodes[out], out_l);
      }
    } else if (num_inputs % 8 == 0 && num_outputs % 4 == 0) {
      for (int out = 0; out < num_outputs; out += 4) {
        __m128 outputs = _mm_loadu_ps(&layer_bias[out]);
        for (int in = 0; in < num_inputs; in += 8) {
          nn_propagate_8to4(&input_nodes[in],
                            &layer_weights[out * num_inputs + in], &outputs,
                            num_inputs);
        }
        if (!output_layer) nn_activate4(&outputs);
        _mm_storeu_ps(&output_nodes[out], outputs);
      }
    } else if (num_inputs % 4 == 0 && num_outputs % 4 == 0) {
      for (int out = 0; out < num_outputs; out += 4) {
        __m128 outputs = _mm_loadu_ps(&layer_bias[out]);
        for (int in = 0; in < num_inputs; in += 4) {
          nn_propagate_4to4(&input_nodes[in],
                            &layer_weights[out * num_inputs + in], &outputs,
                            num_inputs);
        }
        if (!output_layer) nn_activate4(&outputs);
        _mm_storeu_ps(&output_nodes[out], outputs);
      }
    } else if (num_inputs % 8 == 0) {
      for (int out = 0; out < num_outputs; out++) {
        __m128 total = _mm_load1_ps(&layer_bias[out]);
        for (int in = 0; in < num_inputs; in += 8) {
          nn_propagate_8to1(&input_nodes[in],
                            &layer_weights[out * num_inputs + in], &total);
        }
        if (!output_layer) nn_activate4(&total);
        output_nodes[out] = _mm_cvtss_f32(total);
      }
    } else if (num_inputs % 4 == 0) {
      for (int out = 0; out < num_outputs; out++) {
        __m128 total = _mm_load1_ps(&layer_bias[out]);
        for (int in = 0; in < num_inputs; in += 4) {
          nn_propagate_4to1(&input_nodes[in],
                            &layer_weights[out * num_inputs + in], &total);
        }
        if (!output_layer) nn_activate4(&total);
        output_nodes[out] = _mm_cvtss_f32(total);
      }
    } else {
      // Use SSE instructions for scalar operations to avoid the latency of
      // swapping between SIMD and FPU modes.
      for (int out = 0; out < num_outputs; out++) {
        __m128 total = _mm_load1_ps(&layer_bias[out]);
        for (int in_node = 0; in_node < num_inputs; in_node++) {
          __m128 input = _mm_load1_ps(&input_nodes[in_node]);
          __m128 weight =
              _mm_load1_ps(&layer_weights[num_inputs * out + in_node]);
          total = _mm_add_ps(total, _mm_mul_ps(input, weight));
        }
        if (!output_layer) nn_activate4(&total);
        output_nodes[out] = _mm_cvtss_f32(total);
      }
    }
    input_nodes = output_nodes;
    num_inputs = num_outputs;
    buf_index = 1 - buf_index;
  }
}
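
// Illustrative usage sketch (hypothetical function, names, and values; not
// part of the library, kept under #if 0 so it is never built): a caller fills
// an NN_CONFIG with per-layer weight/bias tables and a feature vector. The
// layer sizes below are chosen so the 8-to-4 and 4-to-1 paths above are
// exercised; real networks use trained weights instead of these placeholders.
#if 0
static float example_nn_predict(void) {
  // One hidden layer: 8 input features -> 4 hidden nodes -> 1 output.
  static const float hidden_weights[4 * 8] = { 0.1f };  // 4 nodes x 8 inputs
  static const float hidden_bias[4] = { 0.0f };
  static const float out_weights[1 * 4] = { 0.25f };  // 1 output x 4 nodes
  static const float out_bias[1] = { 0.0f };
  const NN_CONFIG nn_config = {
    .num_inputs = 8,
    .num_outputs = 1,
    .num_hidden_layers = 1,
    .num_hidden_nodes = { 4 },
    .weights = { hidden_weights, out_weights },
    .bias = { hidden_bias, out_bias },
  };
  const float features[8] = { 1.0f, 0.5f, 0.25f, 0.0f, -1.0f, 2.0f, 0.0f, 3.0f };
  float prediction;
  av1_nn_predict_sse3(features, &nn_config, &prediction);
  return prediction;
}
#endif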