Home | History | Annotate | Download | only in encoder
      1 /*
      2  * Copyright (c) 2019, Alliance for Open Media. All rights reserved
      3  *
      4  * This source code is subject to the terms of the BSD 2 Clause License and
      5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6  * was not distributed with this source code in the LICENSE file, you can
      7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8  * Media Patent License 1.0 was not distributed with this source code in the
      9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10  */
     11 
     12 #include <float.h>
     13 
     14 #include "aom_ports/system_state.h"
     15 
     16 #include "av1/common/enums.h"
     17 #include "av1/common/reconinter.h"
     18 
     19 #include "av1/encoder/encoder.h"
     20 #include "av1/encoder/partition_model_weights.h"
     21 #include "av1/encoder/partition_strategy.h"
     22 #include "av1/encoder/rdopt.h"
     23 
     24 // Performs a simple_motion_search with a single reference frame and extract
     25 // the variance of residues. Here features is assumed to be a length 6 array.
     26 // After this function is called, we will store the following in to features:
     27 // features[0] = log(1 + dc_q**2/256)
     28 // features[1] = log(1 + variance_of_residue)
     29 // for i in [2, 3, 4, 5]:
     30 //  features[i] = log(1 + variance_of_residue_in_block[i]/variance_of_residue)
     31 static void get_res_var_features(AV1_COMP *const cpi, MACROBLOCK *x, int mi_row,
     32                                  int mi_col, BLOCK_SIZE bsize,
     33                                  float *features) {
     34   // TODO(chiyotsai (at) google.com): The data this model trained on did not also use
     35   // SIMPLE_TRANSLATION to build the inter_predictor. Retraining and tuning the
     36   // model with the correct data should give better performance.
     37   assert(mi_size_wide[bsize] == mi_size_high[bsize]);
     38 
     39   MACROBLOCKD *xd = &x->e_mbd;
     40 
     41   // Perform a single motion search in Y_PLANE to make a prediction
     42   const int use_subpixel = 0;
     43 
     44   // Start getting the features
     45   int f_idx = 0;
     46 
     47   // Q_INDEX
     48   const int dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd) >> (xd->bd - 8);
     49   aom_clear_system_state();
     50   features[f_idx++] = logf(1.0f + (float)(dc_q * dc_q) / 256.0f);
     51 
     52   // VARIANCE
     53   unsigned int sse = 0;
     54   unsigned int var = 0;
     55   const MV ref_mv_full = { .row = 0, .col = 0 };
     56   av1_simple_motion_sse_var(cpi, x, mi_row, mi_col, bsize, ref_mv_full,
     57                             use_subpixel, &sse, &var);
     58   aom_clear_system_state();
     59   features[f_idx++] = logf(1.0f + (float)var);
     60 
     61   // Regional
     62   const uint8_t *src = x->plane[0].src.buf;
     63   const int src_stride = x->plane[0].src.stride;
     64   const uint8_t *dst = xd->plane[0].dst.buf;
     65   const int dst_stride = xd->plane[0].dst.stride;
     66   const int bw = block_size_wide[bsize];
     67   const int bh = block_size_high[bsize];
     68   const BLOCK_SIZE subsize = get_partition_subsize(bsize, PARTITION_SPLIT);
     69   int r_idx = 0;
     70   for (r_idx = 0; r_idx < 4; r_idx++) {
     71     const int x_idx = (r_idx & 1) * bw / 2;
     72     const int y_idx = (r_idx >> 1) * bh / 2;
     73     const int src_offset = y_idx * src_stride + x_idx;
     74     const int dst_offset = y_idx * dst_stride + x_idx;
     75     const unsigned int sub_var = cpi->fn_ptr[subsize].vf(
     76         src + src_offset, src_stride, dst + dst_offset, dst_stride, &sse);
     77     aom_clear_system_state();
     78     const float var_ratio = (1.0f + (float)sub_var) / (4.0f + (float)var);
     79     features[f_idx++] = var_ratio;
     80   }
     81 }
     82 
     83 void av1_simple_motion_search_based_split(
     84     AV1_COMP *const cpi, MACROBLOCK *x, int mi_row, int mi_col,
     85     BLOCK_SIZE bsize, int *partition_none_allowed, int *partition_horz_allowed,
     86     int *partition_vert_allowed, int *do_rectangular_split,
     87     int *do_square_split) {
     88   const NN_CONFIG *nn_config = NULL;
     89   float split_only_thresh = 0.0f;
     90   if (bsize == BLOCK_128X128) {
     91     nn_config = &av1_simple_motion_search_based_split_nn_config_128;
     92     split_only_thresh = av1_simple_motion_search_based_split_thresh_128;
     93   } else if (bsize == BLOCK_64X64) {
     94     nn_config = &av1_simple_motion_search_based_split_nn_config_64;
     95     split_only_thresh = av1_simple_motion_search_based_split_thresh_64;
     96   } else if (bsize == BLOCK_32X32) {
     97     nn_config = &av1_simple_motion_search_based_split_nn_config_32;
     98     split_only_thresh = av1_simple_motion_search_based_split_thresh_32;
     99   } else if (bsize == BLOCK_16X16) {
    100     nn_config = &av1_simple_motion_search_based_split_nn_config_16;
    101     split_only_thresh = av1_simple_motion_search_based_split_thresh_16;
    102   } else if (bsize == BLOCK_8X8) {
    103     // Disable BLOCK_8X8 for now
    104 #if !CONFIG_DISABLE_FULL_PIXEL_SPLIT_8X8
    105     nn_config = &av1_simple_motion_search_based_split_nn_config_8;
    106     split_only_thresh = av1_simple_motion_search_based_split_thresh_8;
    107 #endif
    108   } else {
    109     assert(0 && "Unexpected block size in simple_motion_based_split");
    110   }
    111   if (nn_config) {
    112     float features[6] = { 0 };
    113     float score = 0;
    114     get_res_var_features(cpi, x, mi_row, mi_col, bsize, features);
    115     av1_nn_predict(features, nn_config, &score);
    116 
    117     if (score > split_only_thresh) {
    118       *partition_none_allowed = 0;
    119       *partition_horz_allowed = 0;
    120       *partition_vert_allowed = 0;
    121       *do_rectangular_split = 0;
    122     }
    123     if (cpi->sf.simple_motion_search_split_only >= 2) {
    124       if (score < -split_only_thresh) *do_square_split = 0;
    125       // For larger scores (>split_only_thresh), none and rectangular partitions
    126       // are skipped. As score reduces, possibility of split decreases. Hence
    127       // for near larger scores (.875 * split_only_thresh to split_only_thresh)
    128       // none partition is disabled, but rectangular partitions are evaluated
    129       // additionally.
    130       if (score > (split_only_thresh * 0.875)) *partition_none_allowed = 0;
    131     }
    132   }
    133 }
    134 
    135 // Given a list of ref frames in refs, performs simple_motion_search on each of
    136 // the refs and returns the ref with the smallest sse. Returns -1 if none of the
    137 // ref in the list is available. Also stores the best sse and var in best_sse,
    138 // best_var, respectively. If save_mv_code is -1, don't update mv_ref_fulls in
    139 // pc_tree. If save_mv_code is between 0 and 3, update mv_ref_fulls under
    140 // pc_tree->split[i]. If save_mv_code is 4, update mv_ref_fulls under pc_tree.
    141 static int simple_motion_search_get_best_ref(
    142     AV1_COMP *const cpi, MACROBLOCK *x, PC_TREE *pc_tree, int mi_row,
    143     int mi_col, BLOCK_SIZE bsize, const int *const refs, int num_refs,
    144     int use_subpixel, int save_mv_code, unsigned int *best_sse,
    145     unsigned int *best_var) {
    146   // TODO(chiyotsai (at) google.com): The calculation of variance currently uses
    147   // bsize, so we might take area outside of the image into account. We need to
    148   // modify the SIMD functions to fix this later.
    149   const AV1_COMMON *const cm = &cpi->common;
    150   int best_ref = -1;
    151 
    152   if (mi_col >= cm->mi_cols || mi_row >= cm->mi_rows) {
    153     // If the whole block is outside of the image, set the var and sse to 0.
    154     *best_var = 0;
    155     *best_sse = 0;
    156 
    157     return best_ref;
    158   }
    159 
    160   // Otherwise do loop through the reference frames and find the one with the
    161   // minimum SSE
    162   const MACROBLOCKD *xd = &x->e_mbd;
    163   const MV *mv_ref_fulls = pc_tree->mv_ref_fulls;
    164 
    165   const int num_planes = 1;
    166 
    167   *best_sse = INT_MAX;
    168 
    169   for (int ref_idx = 0; ref_idx < num_refs; ref_idx++) {
    170     const int ref = refs[ref_idx];
    171 
    172     if (cpi->ref_frame_flags & av1_ref_frame_flag_list[ref]) {
    173       unsigned int curr_sse = 0, curr_var = 0;
    174       av1_simple_motion_search(cpi, x, mi_row, mi_col, bsize, ref,
    175                                mv_ref_fulls[ref], num_planes, use_subpixel);
    176       curr_var = cpi->fn_ptr[bsize].vf(
    177           x->plane[0].src.buf, x->plane[0].src.stride, xd->plane[0].dst.buf,
    178           xd->plane[0].dst.stride, &curr_sse);
    179       if (curr_sse < *best_sse) {
    180         *best_sse = curr_sse;
    181         *best_var = curr_var;
    182         best_ref = ref;
    183       }
    184 
    185       const int new_mv_row = x->best_mv.as_mv.row / 8;
    186       const int new_mv_col = x->best_mv.as_mv.col / 8;
    187       if (save_mv_code == 4) {
    188         pc_tree->mv_ref_fulls[ref].row = new_mv_row;
    189         pc_tree->mv_ref_fulls[ref].col = new_mv_col;
    190       } else if (save_mv_code >= 0 && save_mv_code < 4) {
    191         // Propagate the new motion vectors to a lower level
    192         pc_tree->split[save_mv_code]->mv_ref_fulls[ref].row = new_mv_row;
    193         pc_tree->split[save_mv_code]->mv_ref_fulls[ref].col = new_mv_col;
    194       } else {
    195         assert(save_mv_code == -1 &&
    196                "Unknown code in simple_motion_search_get_best_ref.");
    197       }
    198     }
    199   }
    200 
    201   return best_ref;
    202 }
    203 
    204 // Performs fullpixel simple_motion_search with LAST_FRAME and ALTREF_FRAME on
    205 // each subblock and extract the variance and sse of residues. Then store the
    206 // var and sse from each partition subblock to features. The DC qindex is also
    207 // stored in features.
    208 // Here features is assumed to be a length 19 array.
    209 // After this function is called, we will store the following to features:
    210 // features[0:17] = var and sse from subblocks
    211 // features[18] = DC q_index
    212 static void simple_motion_search_prune_part_features(
    213     AV1_COMP *const cpi, MACROBLOCK *x, PC_TREE *pc_tree, int mi_row,
    214     int mi_col, BLOCK_SIZE bsize, float *features) {
    215   // TODO(chiyotsai (at) google.com): Cache the result of the motion search from the
    216   // larger bsize.
    217   const int w_mi = mi_size_wide[bsize];
    218   const int h_mi = mi_size_high[bsize];
    219   int f_idx = 0;
    220   assert(mi_size_wide[bsize] == mi_size_high[bsize]);
    221   assert(cpi->ref_frame_flags & av1_ref_frame_flag_list[LAST_FRAME] ||
    222          cpi->ref_frame_flags & av1_ref_frame_flag_list[ALTREF_FRAME]);
    223 
    224   // Setting up motion search
    225   const int ref_list[] = { LAST_FRAME, ALTREF_FRAME };
    226   const int num_refs = 2;
    227   const int use_subpixel = 1;
    228 
    229   unsigned int int_features[FEATURE_SIZE_SMS_PRUNE_PART - 1];
    230 
    231   // Doing whole block first to update the mv
    232   simple_motion_search_get_best_ref(
    233       cpi, x, pc_tree, mi_row, mi_col, bsize, ref_list, num_refs, use_subpixel,
    234       4, &int_features[f_idx], &int_features[f_idx + 1]);
    235   f_idx += 2;
    236 
    237   // Split subblocks
    238   BLOCK_SIZE subsize = get_partition_subsize(bsize, PARTITION_SPLIT);
    239   int r_idx = 0;
    240   for (r_idx = 0; r_idx < 4; r_idx++) {
    241     const int sub_mi_col = mi_col + (r_idx & 1) * w_mi / 2;
    242     const int sub_mi_row = mi_row + (r_idx >> 1) * h_mi / 2;
    243 
    244     simple_motion_search_get_best_ref(
    245         cpi, x, pc_tree, sub_mi_row, sub_mi_col, subsize, ref_list, num_refs,
    246         use_subpixel, r_idx, &int_features[f_idx], &int_features[f_idx + 1]);
    247     f_idx += 2;
    248   }
    249 
    250   // Horz subblocks
    251   subsize = get_partition_subsize(bsize, PARTITION_HORZ);
    252   for (r_idx = 0; r_idx < 2; r_idx++) {
    253     const int sub_mi_col = mi_col + 0;
    254     const int sub_mi_row = mi_row + r_idx * h_mi / 2;
    255 
    256     simple_motion_search_get_best_ref(
    257         cpi, x, pc_tree, sub_mi_row, sub_mi_col, subsize, ref_list, num_refs,
    258         use_subpixel, -1, &int_features[f_idx], &int_features[f_idx + 1]);
    259 
    260     f_idx += 2;
    261   }
    262 
    263   // Vert subblock
    264   subsize = get_partition_subsize(bsize, PARTITION_VERT);
    265   for (r_idx = 0; r_idx < 2; r_idx++) {
    266     const int sub_mi_col = mi_col + r_idx * w_mi / 2;
    267     const int sub_mi_row = mi_row + 0;
    268 
    269     simple_motion_search_get_best_ref(
    270         cpi, x, pc_tree, sub_mi_row, sub_mi_col, subsize, ref_list, num_refs,
    271         use_subpixel, -1, &int_features[f_idx], &int_features[f_idx + 1]);
    272 
    273     f_idx += 2;
    274   }
    275 
    276   aom_clear_system_state();
    277   for (int idx = 0; idx < f_idx; idx++) {
    278     features[idx] = logf(1.0f + (float)int_features[idx]);
    279   }
    280 
    281   const MACROBLOCKD *xd = &x->e_mbd;
    282   set_offsets_for_motion_search(cpi, x, mi_row, mi_col, bsize);
    283 
    284   // Q_INDEX
    285   const int dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd) >> (xd->bd - 8);
    286   features[f_idx++] = logf(1.0f + (float)(dc_q * dc_q) / 256.0f);
    287 
    288   // Neighbor stuff
    289   const int has_above = !!xd->above_mbmi;
    290   const int has_left = !!xd->left_mbmi;
    291   const BLOCK_SIZE above_bsize = has_above ? xd->above_mbmi->sb_type : bsize;
    292   const BLOCK_SIZE left_bsize = has_left ? xd->left_mbmi->sb_type : bsize;
    293   features[f_idx++] = (float)has_above;
    294   features[f_idx++] = (float)mi_size_wide_log2[above_bsize];
    295   features[f_idx++] = (float)mi_size_high_log2[above_bsize];
    296   features[f_idx++] = (float)has_left;
    297   features[f_idx++] = (float)mi_size_wide_log2[left_bsize];
    298   features[f_idx++] = (float)mi_size_high_log2[left_bsize];
    299 
    300   assert(f_idx == FEATURE_SIZE_SMS_PRUNE_PART);
    301 }
    302 
    303 void av1_simple_motion_search_prune_part(
    304     AV1_COMP *const cpi, MACROBLOCK *x, PC_TREE *pc_tree, int mi_row,
    305     int mi_col, BLOCK_SIZE bsize, int *partition_none_allowed,
    306     int *partition_horz_allowed, int *partition_vert_allowed,
    307     int *do_square_split, int *do_rectangular_split, int *prune_horz,
    308     int *prune_vert, float *features, int *valid) {
    309   const AV1_COMMON *const cm = &cpi->common;
    310   // Get model parameters
    311   const NN_CONFIG *nn_config = NULL;
    312   const float *prune_thresh = NULL, *only_thresh = NULL;
    313   const float *ml_mean = NULL, *ml_std = NULL;
    314   float normalized_features[FEATURE_SIZE_SMS_PRUNE_PART] = { 0.0f };
    315 
    316   if (bsize == BLOCK_128X128) {
    317     nn_config = &av1_simple_motion_search_prune_part_nn_config_128;
    318     ml_mean = av1_simple_motion_search_prune_part_mean_128;
    319     ml_std = av1_simple_motion_search_prune_part_std_128;
    320     prune_thresh = av1_simple_motion_search_prune_part_prune_thresh_128;
    321     only_thresh = av1_simple_motion_search_prune_part_only_thresh_128;
    322   } else if (bsize == BLOCK_64X64) {
    323     nn_config = &av1_simple_motion_search_prune_part_nn_config_64;
    324     ml_mean = av1_simple_motion_search_prune_part_mean_64;
    325     ml_std = av1_simple_motion_search_prune_part_std_64;
    326     prune_thresh = av1_simple_motion_search_prune_part_prune_thresh_64;
    327     only_thresh = av1_simple_motion_search_prune_part_only_thresh_64;
    328   } else if (bsize == BLOCK_32X32) {
    329     nn_config = &av1_simple_motion_search_prune_part_nn_config_32;
    330     ml_mean = av1_simple_motion_search_prune_part_mean_32;
    331     ml_std = av1_simple_motion_search_prune_part_std_32;
    332     prune_thresh = av1_simple_motion_search_prune_part_prune_thresh_32;
    333     only_thresh = av1_simple_motion_search_prune_part_only_thresh_32;
    334   } else if (bsize == BLOCK_16X16) {
    335     nn_config = &av1_simple_motion_search_prune_part_nn_config_16;
    336     ml_mean = av1_simple_motion_search_prune_part_mean_16;
    337     ml_std = av1_simple_motion_search_prune_part_std_16;
    338     prune_thresh = av1_simple_motion_search_prune_part_prune_thresh_16;
    339     only_thresh = av1_simple_motion_search_prune_part_only_thresh_16;
    340   } else if (bsize == BLOCK_8X8) {
    341     nn_config = &av1_simple_motion_search_prune_part_nn_config_8;
    342     ml_mean = av1_simple_motion_search_prune_part_mean_8;
    343     ml_std = av1_simple_motion_search_prune_part_std_8;
    344     prune_thresh = av1_simple_motion_search_prune_part_prune_thresh_8;
    345     only_thresh = av1_simple_motion_search_prune_part_only_thresh_8;
    346   } else {
    347     assert(0 && "Unexpected block size in simple_motion_prune_part");
    348   }
    349 
    350   // If there is no valid threshold, return immediately.
    351   if (!nn_config || (prune_thresh[PARTITION_HORZ] == 0.0f &&
    352                      prune_thresh[PARTITION_VERT] == 0.0f)) {
    353     return;
    354   }
    355   if (bsize < BLOCK_8X8) {
    356     return;
    357   }
    358 
    359   // Get features
    360   simple_motion_search_prune_part_features(cpi, x, pc_tree, mi_row, mi_col,
    361                                            bsize, features);
    362   *valid = 1;
    363   for (int f_idx = 0; f_idx < FEATURE_SIZE_SMS_PRUNE_PART; f_idx++) {
    364     normalized_features[f_idx] =
    365         (features[f_idx] - ml_mean[f_idx]) / ml_std[f_idx];
    366   }
    367 
    368   // Get probabilities
    369   float scores[EXT_PARTITION_TYPES] = { 0.0f },
    370         probs[EXT_PARTITION_TYPES] = { 0.0f };
    371   const int num_classes = (bsize == BLOCK_128X128 || bsize == BLOCK_8X8)
    372                               ? PARTITION_TYPES
    373                               : EXT_PARTITION_TYPES;
    374 
    375   av1_nn_predict(normalized_features, nn_config, scores);
    376   aom_clear_system_state();
    377 
    378   av1_nn_softmax(scores, probs, num_classes);
    379 
    380   // Determine if we should prune rectangular partitions.
    381   if (cpi->sf.simple_motion_search_prune_rect && !frame_is_intra_only(cm) &&
    382       (*partition_horz_allowed || *partition_vert_allowed) &&
    383       bsize >= BLOCK_8X8 && !av1_superres_scaled(cm)) {
    384     *prune_horz = probs[PARTITION_HORZ] <= prune_thresh[PARTITION_HORZ];
    385     *prune_vert = probs[PARTITION_VERT] <= prune_thresh[PARTITION_VERT];
    386   }
    387 
    388   // Silence compiler warnings
    389   (void)only_thresh;
    390   (void)partition_none_allowed;
    391   (void)do_square_split;
    392   (void)do_rectangular_split;
    393 }
    394 
    395 // Early terminates PARTITION_NONE using simple_motion_search features and the
    396 // rate, distortion, and rdcost of PARTITION_NONE. This is only called when:
    397 //  - The frame is a show frame
    398 //  - The frame is not intra only
    399 //  - The current bsize is > BLOCK_8X8
    400 //  - blk_row + blk_height/2 < total_rows and blk_col + blk_width/2 < total_cols
    401 void av1_simple_motion_search_early_term_none(
    402     AV1_COMP *const cpi, MACROBLOCK *x, PC_TREE *pc_tree, int mi_row,
    403     int mi_col, BLOCK_SIZE bsize, const RD_STATS *none_rdc,
    404     int *early_terminate, float *simple_motion_features,
    405     int *simple_motion_features_are_valid) {
    406   // TODO(chiyotsai (at) google.com): There are other features we can extract from
    407   // PARTITION_NONE. Play with this later.
    408   int f_idx = 0;
    409   if (!*simple_motion_features_are_valid) {
    410     simple_motion_search_prune_part_features(cpi, x, pc_tree, mi_row, mi_col,
    411                                              bsize, simple_motion_features);
    412     *simple_motion_features_are_valid = 1;
    413   }
    414   f_idx = 25;
    415 
    416   simple_motion_features[f_idx++] = logf(1.0f + (float)none_rdc->rate);
    417   simple_motion_features[f_idx++] = logf(1.0f + (float)none_rdc->dist);
    418   simple_motion_features[f_idx++] = logf(1.0f + (float)none_rdc->rdcost);
    419 
    420   assert(f_idx == FEATURE_SIZE_SMS_TERM_NONE);
    421 
    422   const float *ml_mean = NULL;
    423   const float *ml_std = NULL;
    424   const float *ml_model = NULL;
    425 
    426   if (bsize == BLOCK_128X128) {
    427     ml_mean = av1_simple_motion_search_term_none_mean_128;
    428     ml_std = av1_simple_motion_search_term_none_std_128;
    429     ml_model = av1_simple_motion_search_term_none_model_128;
    430   } else if (bsize == BLOCK_64X64) {
    431     ml_mean = av1_simple_motion_search_term_none_mean_64;
    432     ml_std = av1_simple_motion_search_term_none_std_64;
    433     ml_model = av1_simple_motion_search_term_none_model_64;
    434   } else if (bsize == BLOCK_32X32) {
    435     ml_mean = av1_simple_motion_search_term_none_mean_32;
    436     ml_std = av1_simple_motion_search_term_none_std_32;
    437     ml_model = av1_simple_motion_search_term_none_model_32;
    438   } else if (bsize == BLOCK_16X16) {
    439     ml_mean = av1_simple_motion_search_term_none_mean_16;
    440     ml_std = av1_simple_motion_search_term_none_std_16;
    441     ml_model = av1_simple_motion_search_term_none_model_16;
    442   } else {
    443     assert(0 && "Unexpected block size in simple_motion_term_none");
    444   }
    445 
    446   if (ml_model) {
    447     float score = 0.0f;
    448     for (f_idx = 0; f_idx < FEATURE_SIZE_SMS_TERM_NONE; f_idx++) {
    449       score += ml_model[f_idx] *
    450                (simple_motion_features[f_idx] - ml_mean[f_idx]) / ml_std[f_idx];
    451     }
    452     score += ml_model[FEATURE_SIZE_SMS_TERM_NONE];
    453 
    454     if (score >= 0.0f) {
    455       *early_terminate = 1;
    456     }
    457   }
    458 }
    459 
    460 static void firstpass_simple_motion_search_features(
    461     AV1_COMP *const cpi, MACROBLOCK *x, PC_TREE *pc_tree, int mi_row,
    462     int mi_col, BLOCK_SIZE bsize, float *features) {
    463   assert(mi_size_wide[bsize] == mi_size_high[bsize]);
    464   assert(cpi->ref_frame_flags & av1_ref_frame_flag_list[LAST_FRAME] ||
    465          cpi->ref_frame_flags & av1_ref_frame_flag_list[ALTREF_FRAME]);
    466 
    467   // Setting up motion search
    468   const int ref_list[] = { LAST_FRAME, ALTREF_FRAME };
    469   const int num_refs = 2;
    470   const int use_subpixel = 0;
    471 
    472   unsigned int int_features[10] = { 0 };
    473 
    474   int f_idx = 0;
    475   // Doing whole block first to update the mv
    476   simple_motion_search_get_best_ref(
    477       cpi, x, pc_tree, mi_row, mi_col, bsize, ref_list, num_refs, use_subpixel,
    478       4, &int_features[f_idx], &int_features[f_idx + 1]);
    479   f_idx += 2;
    480 
    481   // Split subblocks
    482   const BLOCK_SIZE subsize = get_partition_subsize(bsize, PARTITION_SPLIT);
    483   const int w_mi = mi_size_wide[bsize];
    484   const int h_mi = mi_size_high[bsize];
    485   for (int r_idx = 0; r_idx < 4; r_idx++) {
    486     const int sub_mi_col = mi_col + (r_idx & 1) * w_mi / 2;
    487     const int sub_mi_row = mi_row + (r_idx >> 1) * h_mi / 2;
    488 
    489     simple_motion_search_get_best_ref(
    490         cpi, x, pc_tree, sub_mi_row, sub_mi_col, subsize, ref_list, num_refs,
    491         use_subpixel, r_idx, &int_features[f_idx], &int_features[f_idx + 1]);
    492     f_idx += 2;
    493   }
    494 
    495   aom_clear_system_state();
    496   for (int idx = 0; idx < f_idx; idx++) {
    497     features[idx] = logf(1.0f + (float)int_features[idx]);
    498   }
    499 
    500   const MACROBLOCKD *xd = &x->e_mbd;
    501   set_offsets_for_motion_search(cpi, x, mi_row, mi_col, bsize);
    502 
    503   // Q_INDEX
    504   const int dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd) >> (xd->bd - 8);
    505   features[f_idx++] = logf(1.0f + (float)(dc_q * dc_q) / 256.0f);
    506 
    507   // Neighbor stuff
    508   const int has_above = !!xd->above_mbmi;
    509   const int has_left = !!xd->left_mbmi;
    510   const BLOCK_SIZE above_bsize = has_above ? xd->above_mbmi->sb_type : bsize;
    511   const BLOCK_SIZE left_bsize = has_left ? xd->left_mbmi->sb_type : bsize;
    512   features[f_idx++] = (float)has_above;
    513   features[f_idx++] = (float)mi_size_wide_log2[above_bsize];
    514   features[f_idx++] = (float)mi_size_high_log2[above_bsize];
    515   features[f_idx++] = (float)has_left;
    516   features[f_idx++] = (float)mi_size_wide_log2[left_bsize];
    517   features[f_idx++] = (float)mi_size_high_log2[left_bsize];
    518 }
    519 
    520 void av1_firstpass_simple_motion_search_early_term(AV1_COMP *const cpi,
    521                                                    MACROBLOCK *x,
    522                                                    PC_TREE *pc_tree, int mi_row,
    523                                                    int mi_col, BLOCK_SIZE bsize,
    524                                                    const RD_STATS *none_rdc,
    525                                                    int *do_square_split) {
    526   const NN_CONFIG *nn_config = NULL;
    527   float thresh = 0.0f;
    528   const float *ml_mean = NULL, *ml_std = NULL;
    529   if (bsize == BLOCK_32X32) {
    530     nn_config = &av1_fp_simple_motion_search_term_none_nn_config_32;
    531     ml_mean = av1_fp_simple_motion_search_term_none_mean_32;
    532     ml_std = av1_fp_simple_motion_search_term_none_std_32;
    533     thresh = av1_fp_simple_motion_search_term_none_thresh_32;
    534   } else if (bsize == BLOCK_16X16) {
    535     nn_config = &av1_fp_simple_motion_search_term_none_nn_config_16;
    536     ml_mean = av1_fp_simple_motion_search_term_none_mean_16;
    537     ml_std = av1_fp_simple_motion_search_term_none_std_16;
    538     thresh = av1_fp_simple_motion_search_term_none_thresh_16;
    539   } else if (bsize == BLOCK_8X8) {
    540     nn_config = &av1_fp_simple_motion_search_term_none_nn_config_8;
    541     ml_mean = av1_fp_simple_motion_search_term_none_mean_8;
    542     ml_std = av1_fp_simple_motion_search_term_none_std_8;
    543     thresh = av1_fp_simple_motion_search_term_none_thresh_8;
    544   } else {
    545     assert(0 &&
    546            "Unexpected bsize in firstpass_simple_motion_search_early_term");
    547     return;
    548   }
    549 
    550   float ml_features[FEATURE_SIZE_FP_SMS_TERM_NONE] = { 0.0f };
    551 
    552   firstpass_simple_motion_search_features(cpi, x, pc_tree, mi_row, mi_col,
    553                                           bsize, ml_features);
    554   int f_idx = 17;
    555 
    556   ml_features[f_idx++] = logf(1.0f + (float)none_rdc->rate);
    557   ml_features[f_idx++] = logf(1.0f + (float)none_rdc->dist);
    558   ml_features[f_idx++] = logf(1.0f + (float)none_rdc->rdcost);
    559 
    560   for (f_idx = 0; f_idx < 20; f_idx++) {
    561     ml_features[f_idx] = (ml_features[f_idx] - ml_mean[f_idx]) / ml_std[f_idx];
    562   }
    563 
    564   // Get probabilities
    565   float score = 0.0f;
    566 
    567   av1_nn_predict(ml_features, nn_config, &score);
    568   aom_clear_system_state();
    569 
    570   // Determine if we should prune square partitions.
    571   if (score < thresh) {
    572     *do_square_split = 0;
    573   }
    574 }
    575 
    576 void av1_get_max_min_partition_features(AV1_COMP *const cpi, MACROBLOCK *x,
    577                                         int mi_row, int mi_col,
    578                                         float *features) {
    579   AV1_COMMON *const cm = &cpi->common;
    580   MACROBLOCKD *xd = &x->e_mbd;
    581   const BLOCK_SIZE sb_size = cm->seq_params.sb_size;
    582 
    583   assert(sb_size == BLOCK_128X128);
    584 
    585   int f_idx = 0;
    586 
    587   const int dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd) >> (xd->bd - 8);
    588   aom_clear_system_state();
    589   const float log_q_sq = logf(1.0f + (float)(dc_q * dc_q) / 256.0f);
    590 
    591   // Perform full-pixel single motion search in Y plane of 16x16 mbs in the sb
    592   float sum_mv_row_sq = 0;
    593   float sum_mv_row = 0;
    594   float min_abs_mv_row = FLT_MAX;
    595   float max_abs_mv_row = 0;
    596 
    597   float sum_mv_col_sq = 0;
    598   float sum_mv_col = 0;
    599   float min_abs_mv_col = FLT_MAX;
    600   float max_abs_mv_col = 0;
    601 
    602   float sum_log_sse_sq = 0;
    603   float sum_log_sse = 0;
    604   float min_log_sse = FLT_MAX;
    605   float max_log_sse = 0;
    606 
    607   const BLOCK_SIZE mb_size = BLOCK_16X16;
    608   const int mb_rows = block_size_high[sb_size] / block_size_high[mb_size];
    609   const int mb_cols = block_size_wide[sb_size] / block_size_wide[mb_size];
    610   const int mb_in_mi_size_high_log2 = mi_size_high_log2[mb_size];
    611   const int mb_in_mi_size_wide_log2 = mi_size_wide_log2[mb_size];
    612 
    613   for (int mb_row = 0; mb_row < mb_rows; mb_row++)
    614     for (int mb_col = 0; mb_col < mb_cols; mb_col++) {
    615       const int this_mi_row = mi_row + (mb_row << mb_in_mi_size_high_log2);
    616       const int this_mi_col = mi_col + (mb_col << mb_in_mi_size_wide_log2);
    617       unsigned int sse = 0;
    618       unsigned int var = 0;
    619       const MV ref_mv_full = { .row = 0, .col = 0 };
    620 
    621       av1_simple_motion_sse_var(cpi, x, this_mi_row, this_mi_col, mb_size,
    622                                 ref_mv_full, 0, &sse, &var);
    623 
    624       aom_clear_system_state();
    625       const float mv_row = (float)(x->best_mv.as_mv.row / 8);
    626       const float mv_col = (float)(x->best_mv.as_mv.col / 8);
    627       const float log_sse = logf(1.0f + (float)sse);
    628       const float abs_mv_row = fabsf(mv_row);
    629       const float abs_mv_col = fabsf(mv_col);
    630 
    631       sum_mv_row_sq += mv_row * mv_row;
    632       sum_mv_row += mv_row;
    633       sum_mv_col_sq += mv_col * mv_col;
    634       sum_mv_col += mv_col;
    635 
    636       if (abs_mv_row < min_abs_mv_row) min_abs_mv_row = abs_mv_row;
    637       if (abs_mv_row > max_abs_mv_row) max_abs_mv_row = abs_mv_row;
    638       if (abs_mv_col < min_abs_mv_col) min_abs_mv_col = abs_mv_col;
    639       if (abs_mv_col > max_abs_mv_col) max_abs_mv_col = abs_mv_col;
    640 
    641       sum_log_sse_sq += log_sse * log_sse;
    642       sum_log_sse += log_sse;
    643       if (log_sse < min_log_sse) min_log_sse = log_sse;
    644       if (log_sse > max_log_sse) max_log_sse = log_sse;
    645     }
    646   aom_clear_system_state();
    647   const float avg_mv_row = sum_mv_row / 64.0f;
    648   const float var_mv_row = sum_mv_row_sq / 64.0f - avg_mv_row * avg_mv_row;
    649 
    650   const float avg_mv_col = sum_mv_col / 64.0f;
    651   const float var_mv_col = sum_mv_col_sq / 64.0f - avg_mv_col * avg_mv_col;
    652 
    653   const float avg_log_sse = sum_log_sse / 64.0f;
    654   const float var_log_sse = sum_log_sse_sq / 64.0f - avg_log_sse * avg_log_sse;
    655 
    656   features[f_idx++] = avg_log_sse;
    657   features[f_idx++] = avg_mv_col;
    658   features[f_idx++] = avg_mv_row;
    659   features[f_idx++] = log_q_sq;
    660   features[f_idx++] = max_abs_mv_col;
    661   features[f_idx++] = max_abs_mv_row;
    662   features[f_idx++] = max_log_sse;
    663   features[f_idx++] = min_abs_mv_col;
    664   features[f_idx++] = min_abs_mv_row;
    665   features[f_idx++] = min_log_sse;
    666   features[f_idx++] = var_log_sse;
    667   features[f_idx++] = var_mv_col;
    668   features[f_idx++] = var_mv_row;
    669 
    670   assert(f_idx == FEATURE_SIZE_MAX_MIN_PART_PRED);
    671 }
    672 
    673 BLOCK_SIZE av1_predict_max_partition(AV1_COMP *const cpi, MACROBLOCK *const x,
    674                                      const float *features) {
    675   float scores[MAX_NUM_CLASSES_MAX_MIN_PART_PRED] = { 0.0f },
    676         probs[MAX_NUM_CLASSES_MAX_MIN_PART_PRED] = { 0.0f };
    677   const NN_CONFIG *nn_config = &av1_max_part_pred_nn_config;
    678 
    679   assert(cpi->sf.auto_max_partition_based_on_simple_motion != NOT_IN_USE);
    680 
    681   aom_clear_system_state();
    682   av1_nn_predict(features, nn_config, scores);
    683   av1_nn_softmax(scores, probs, MAX_NUM_CLASSES_MAX_MIN_PART_PRED);
    684 
    685   int result = MAX_NUM_CLASSES_MAX_MIN_PART_PRED - 1;
    686   if (cpi->sf.auto_max_partition_based_on_simple_motion == DIRECT_PRED) {
    687     result = 0;
    688     float max_prob = probs[0];
    689     for (int i = 1; i < MAX_NUM_CLASSES_MAX_MIN_PART_PRED; ++i) {
    690       if (probs[i] > max_prob) {
    691         max_prob = probs[i];
    692         result = i;
    693       }
    694     }
    695   } else if (cpi->sf.auto_max_partition_based_on_simple_motion ==
    696              RELAXED_PRED) {
    697     for (result = MAX_NUM_CLASSES_MAX_MIN_PART_PRED - 1; result >= 0;
    698          --result) {
    699       if (result < MAX_NUM_CLASSES_MAX_MIN_PART_PRED - 1) {
    700         probs[result] += probs[result + 1];
    701       }
    702       if (probs[result] > 0.2) break;
    703     }
    704   } else if (cpi->sf.auto_max_partition_based_on_simple_motion == ADAPT_PRED) {
    705     const BLOCK_SIZE sb_size = cpi->common.seq_params.sb_size;
    706     MACROBLOCKD *const xd = &x->e_mbd;
    707     // TODO(debargha): x->source_variance is unavailable at this point,
    708     // so compute. The redundant recomputation later can be removed.
    709     const unsigned int source_variance =
    710         is_cur_buf_hbd(xd)
    711             ? av1_high_get_sby_perpixel_variance(cpi, &x->plane[0].src, sb_size,
    712                                                  xd->bd)
    713             : av1_get_sby_perpixel_variance(cpi, &x->plane[0].src, sb_size);
    714     if (source_variance > 16) {
    715       const double thresh = source_variance < 128 ? 0.05 : 0.1;
    716       for (result = MAX_NUM_CLASSES_MAX_MIN_PART_PRED - 1; result >= 0;
    717            --result) {
    718         if (result < MAX_NUM_CLASSES_MAX_MIN_PART_PRED - 1) {
    719           probs[result] += probs[result + 1];
    720         }
    721         if (probs[result] > thresh) break;
    722       }
    723     }
    724   }
    725 
    726   return (BLOCK_SIZE)((result + 2) * 3);
    727 }
    728