/*
 * Copyright (c) 2018, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <stdbool.h>
#include <assert.h>
#include <pmmintrin.h>

#include "config/av1_rtcd.h"
#include "av1/encoder/ml.h"

// In order to avoid the high latency of swapping between FPU and SIMD
// operations, we keep the result in a 128-bit register even though we only
// care about a single value.
static void nn_propagate_8to1(const float *const inputs,
                              const float *const weights,
                              __m128 *const output) {
  const __m128 inputs_h = _mm_loadu_ps(&inputs[4]);
  const __m128 inputs_l = _mm_loadu_ps(inputs);

  const __m128 weights_h = _mm_loadu_ps(&weights[4]);
  const __m128 weights_l = _mm_loadu_ps(weights);

  const __m128 mul_h = _mm_mul_ps(inputs_h, weights_h);
  const __m128 mul_l = _mm_mul_ps(inputs_l, weights_l);
  // [7 6 5 4] [3 2 1 0] (weight and input indices)

  const __m128 vadd = _mm_add_ps(mul_l, mul_h);
  // [7+3 6+2 5+1 4+0]
  const __m128 hadd1 = _mm_hadd_ps(vadd, vadd);
  // [7+6+3+2 5+4+1+0 7+6+3+2 5+4+1+0]
  const __m128 hadd2 = _mm_hadd_ps(hadd1, hadd1);
  // [7+6+5+4+3+2+1+0 7+6+5+4+3+2+1+0 7+6+5+4+3+2+1+0 7+6+5+4+3+2+1+0]
  *output = _mm_add_ps(*output, hadd2);
}
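
// For reference: ignoring the register layout, nn_propagate_8to1() simply
// accumulates an 8-term dot product into every lane of *output (the caller
// later extracts lane 0 with _mm_cvtss_f32()). The helper below is only an
// illustrative scalar sketch of that arithmetic; it is not used by the SSE3
// paths in this file.
static inline float nn_dot_product_8_sketch(const float *const inputs,
                                            const float *const weights) {
  float sum = 0.0f;
  for (int i = 0; i < 8; i++) sum += inputs[i] * weights[i];
  return sum;
}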

static void nn_propagate_4to1(const float *const inputs,
                              const float *const weights,
                              __m128 *const output) {
  const __m128 inputs128 = _mm_loadu_ps(inputs);

  const __m128 weights128 = _mm_loadu_ps(weights);

  const __m128 mul = _mm_mul_ps(inputs128, weights128);
  // [3 2 1 0] (weight and input indices)

  const __m128 hadd1 = _mm_hadd_ps(mul, mul);
  // [3+2 1+0 3+2 1+0]
  const __m128 hadd2 = _mm_hadd_ps(hadd1, hadd1);
  // [3+2+1+0 3+2+1+0 3+2+1+0 3+2+1+0]
  *output = _mm_add_ps(*output, hadd2);
}

static void nn_propagate_4to4(const float *const inputs,
                              const float *const weights, __m128 *const outputs,
                              const int num_inputs) {
  const __m128 inputs128 = _mm_loadu_ps(inputs);

  __m128 hadd[2];
  for (int i = 0; i < 2; i++) {  // For each pair of outputs
    const __m128 weight0 = _mm_loadu_ps(&weights[2 * i * num_inputs]);
    const __m128 mul0 = _mm_mul_ps(weight0, inputs128);
    const __m128 weight1 = _mm_loadu_ps(&weights[(2 * i + 1) * num_inputs]);
    const __m128 mul1 = _mm_mul_ps(weight1, inputs128);
    hadd[i] = _mm_hadd_ps(mul0, mul1);
  }
  // hadd[0] = [7+6 5+4 3+2 1+0] (weight indices)
  // hadd[1] = [15+14 13+12 11+10 9+8]

  const __m128 hh = _mm_hadd_ps(hadd[0], hadd[1]);
  // [15+14+13+12 11+10+9+8 7+6+5+4 3+2+1+0]

  *outputs = _mm_add_ps(*outputs, hh);
}

static void nn_propagate_4to8(const float *const inputs,
                              const float *const weights, __m128 *const out_h,
                              __m128 *const out_l, const int num_inputs) {
  const __m128 inputs128 = _mm_loadu_ps(inputs);

  __m128 hadd[4];
  for (int i = 0; i < 4; i++) {  // For each pair of outputs
    const __m128 weight0 = _mm_loadu_ps(&weights[2 * i * num_inputs]);
    const __m128 weight1 = _mm_loadu_ps(&weights[(2 * i + 1) * num_inputs]);
    const __m128 mul0 = _mm_mul_ps(inputs128, weight0);
    const __m128 mul1 = _mm_mul_ps(inputs128, weight1);
    hadd[i] = _mm_hadd_ps(mul0, mul1);
  }
  // hadd[0] = [7+6 5+4 3+2 1+0] (weight indices)
  // hadd[1] = [15+14 13+12 11+10 9+8]
  // hadd[2] = [23+22 21+20 19+18 17+16]
  // hadd[3] = [31+30 29+28 27+26 25+24]

  const __m128 hh0 = _mm_hadd_ps(hadd[0], hadd[1]);
  // [15+14+13+12 11+10+9+8 7+6+5+4 3+2+1+0]
  const __m128 hh1 = _mm_hadd_ps(hadd[2], hadd[3]);
  // [31+30+29+28 27+26+25+24 23+22+21+20 19+18+17+16]

  *out_h = _mm_add_ps(*out_h, hh1);
  *out_l = _mm_add_ps(*out_l, hh0);
}

static void nn_propagate_8to4(const float *const inputs,
                              const float *const weights, __m128 *const outputs,
                              const int num_inputs) {
  const __m128 inputs_h = _mm_loadu_ps(inputs + 4);
  const __m128 inputs_l = _mm_loadu_ps(inputs);
  // [7 6 5 4] [3 2 1 0] (input indices)

  __m128 add[4];
  for (int i = 0; i < 4; i++) {  // For each output:
    const __m128 weight_h = _mm_loadu_ps(&weights[i * num_inputs + 4]);
    const __m128 weight_l = _mm_loadu_ps(&weights[i * num_inputs]);
    const __m128 mul_h = _mm_mul_ps(inputs_h, weight_h);
    const __m128 mul_l = _mm_mul_ps(inputs_l, weight_l);
    add[i] = _mm_add_ps(mul_l, mul_h);
  }
  // add[0] = [7+3 6+2 5+1 4+0]
  // add[1] = [15+11 14+10 13+9 12+8]
  // add[2] = [23+19 22+18 21+17 20+16]
  // add[3] = [31+27 30+26 29+25 28+24]

  const __m128 hadd_h = _mm_hadd_ps(add[2], add[3]);
  // [31+30+27+26 29+28+25+24 23+22+19+18 21+20+17+16]
  const __m128 hadd_l = _mm_hadd_ps(add[0], add[1]);
  // [15+14+11+10 13+12+9+8 7+6+3+2 5+4+1+0]

  const __m128 haddhadd = _mm_hadd_ps(hadd_l, hadd_h);
  // [31+30+29+28+27+26+25+24 23+22+21+20+19+18+17+16
  //  15+14+13+12+11+10+9+8 7+6+5+4+3+2+1+0]

  *outputs = _mm_add_ps(*outputs, haddhadd);
}

static void nn_activate8(__m128 *out_h, __m128 *out_l) {
  const __m128 zero = _mm_setzero_ps();
  *out_h = _mm_max_ps(*out_h, zero);
  *out_l = _mm_max_ps(*out_l, zero);
}

static void nn_activate4(__m128 *x) { *x = _mm_max_ps(*x, _mm_setzero_ps()); }
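
// Taken together, the helpers above implement, for each output node o of a
// layer,
//   out[o] = bias[o] + sum_{i} in[i] * weights[o * num_inputs + i]
// followed by a ReLU (max with zero) on every layer except the output layer.
// The SIMD variants only differ in how many inputs and outputs they handle
// per call. The function below is an illustrative scalar sketch of one layer
// under those assumptions; it is not part of the SSE3 fast path.
static inline void nn_layer_scalar_sketch(const float *in,
                                          const float *weights,
                                          const float *bias, int num_inputs,
                                          int num_outputs, bool apply_relu,
                                          float *out) {
  for (int o = 0; o < num_outputs; o++) {
    float sum = bias[o];
    for (int i = 0; i < num_inputs; i++) {
      sum += in[i] * weights[o * num_inputs + i];
    }
    out[o] = (apply_relu && sum < 0.0f) ? 0.0f : sum;
  }
}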

// Calculate prediction based on the given input features and neural net config.
// Assume there are no more than NN_MAX_NODES_PER_LAYER nodes in each hidden
// layer.
void av1_nn_predict_sse3(const float *input_nodes,
                         const NN_CONFIG *const nn_config,
                         float *const output) {
  float buf[2][NN_MAX_NODES_PER_LAYER];
  int buf_index = 0;
  int num_inputs = nn_config->num_inputs;

  // Propagate through the hidden layers; the final iteration handles the
  // output layer.
  for (int layer = 0; layer <= nn_config->num_hidden_layers; layer++) {
    const float *layer_weights = nn_config->weights[layer];
    const float *layer_bias = nn_config->bias[layer];
    bool output_layer = (layer == nn_config->num_hidden_layers);
    float *const output_nodes = output_layer ? output : buf[buf_index];
    const int num_outputs = output_layer ? nn_config->num_outputs
                                         : nn_config->num_hidden_nodes[layer];

    if (num_inputs % 4 == 0 && num_outputs % 8 == 0) {
      for (int out = 0; out < num_outputs; out += 8) {
        __m128 out_h = _mm_loadu_ps(&layer_bias[out + 4]);
        __m128 out_l = _mm_loadu_ps(&layer_bias[out]);
        for (int in = 0; in < num_inputs; in += 4) {
          nn_propagate_4to8(&input_nodes[in],
                            &layer_weights[out * num_inputs + in], &out_h,
                            &out_l, num_inputs);
        }
        if (!output_layer) nn_activate8(&out_h, &out_l);
        _mm_storeu_ps(&output_nodes[out + 4], out_h);
        _mm_storeu_ps(&output_nodes[out], out_l);
      }
    } else if (num_inputs % 8 == 0 && num_outputs % 4 == 0) {
      for (int out = 0; out < num_outputs; out += 4) {
        __m128 outputs = _mm_loadu_ps(&layer_bias[out]);
        for (int in = 0; in < num_inputs; in += 8) {
          nn_propagate_8to4(&input_nodes[in],
                            &layer_weights[out * num_inputs + in], &outputs,
                            num_inputs);
        }
        if (!output_layer) nn_activate4(&outputs);
        _mm_storeu_ps(&output_nodes[out], outputs);
      }
    } else if (num_inputs % 4 == 0 && num_outputs % 4 == 0) {
      for (int out = 0; out < num_outputs; out += 4) {
        __m128 outputs = _mm_loadu_ps(&layer_bias[out]);
        for (int in = 0; in < num_inputs; in += 4) {
          nn_propagate_4to4(&input_nodes[in],
                            &layer_weights[out * num_inputs + in], &outputs,
                            num_inputs);
        }
        if (!output_layer) nn_activate4(&outputs);
        _mm_storeu_ps(&output_nodes[out], outputs);
      }
    } else if (num_inputs % 8 == 0) {
      for (int out = 0; out < num_outputs; out++) {
        __m128 total = _mm_load1_ps(&layer_bias[out]);
        for (int in = 0; in < num_inputs; in += 8) {
          nn_propagate_8to1(&input_nodes[in],
                            &layer_weights[out * num_inputs + in], &total);
        }
        if (!output_layer) nn_activate4(&total);
        output_nodes[out] = _mm_cvtss_f32(total);
      }
    } else if (num_inputs % 4 == 0) {
      for (int out = 0; out < num_outputs; out++) {
        __m128 total = _mm_load1_ps(&layer_bias[out]);
        for (int in = 0; in < num_inputs; in += 4) {
          nn_propagate_4to1(&input_nodes[in],
                            &layer_weights[out * num_inputs + in], &total);
        }
        if (!output_layer) nn_activate4(&total);
        output_nodes[out] = _mm_cvtss_f32(total);
      }
    } else {
      // Use SSE instructions for scalar operations to avoid the latency of
      // swapping between SIMD and FPU modes.
      for (int out = 0; out < num_outputs; out++) {
        __m128 total = _mm_load1_ps(&layer_bias[out]);
        for (int in_node = 0; in_node < num_inputs; in_node++) {
          __m128 input = _mm_load1_ps(&input_nodes[in_node]);
          __m128 weight =
              _mm_load1_ps(&layer_weights[num_inputs * out + in_node]);
          total = _mm_add_ps(total, _mm_mul_ps(input, weight));
        }
        if (!output_layer) nn_activate4(&total);
        output_nodes[out] = _mm_cvtss_f32(total);
      }
    }
    input_nodes = output_nodes;
    num_inputs = num_outputs;
    buf_index = 1 - buf_index;
  }
}
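
// Illustrative call pattern (a sketch; NUM_FEATURES, NUM_OUTPUTS and
// fill_features() are hypothetical, and the array sizes must match
// nn_config->num_inputs and nn_config->num_outputs respectively):
//   float features[NUM_FEATURES];
//   float scores[NUM_OUTPUTS];
//   fill_features(features);
//   av1_nn_predict_sse3(features, &nn_config, scores);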