Home | History | Annotate | Download | only in encoder
      1 /*
      2  *  Copyright (c) 2010 The WebM project authors. All Rights Reserved.
      3  *
      4  *  Use of this source code is governed by a BSD-style license
      5  *  that can be found in the LICENSE file in the root of the source
      6  *  tree. An additional intellectual property rights grant can be found
      7  *  in the file PATENTS.  All contributing project authors may
      8  *  be found in the AUTHORS file in the root of the source tree.
      9  */
     10 
     11 #include <math.h>
     12 
     13 #include "vp9/common/vp9_common.h"
     14 #include "vp9/common/vp9_entropymode.h"
     15 
     16 #include "vp9/encoder/vp9_cost.h"
     17 #include "vp9/encoder/vp9_encodemv.h"
     18 
     19 #include "vpx_dsp/vpx_dsp_common.h"
     20 
     21 static struct vp9_token mv_joint_encodings[MV_JOINTS];
     22 static struct vp9_token mv_class_encodings[MV_CLASSES];
     23 static struct vp9_token mv_fp_encodings[MV_FP_SIZE];
     24 
     25 void vp9_entropy_mv_init(void) {
     26   vp9_tokens_from_tree(mv_joint_encodings, vp9_mv_joint_tree);
     27   vp9_tokens_from_tree(mv_class_encodings, vp9_mv_class_tree);
     28   vp9_tokens_from_tree(mv_fp_encodings, vp9_mv_fp_tree);
     29 }
     30 
     31 static void encode_mv_component(vpx_writer *w, int comp,
     32                                 const nmv_component *mvcomp, int usehp) {
     33   int offset;
     34   const int sign = comp < 0;
     35   const int mag = sign ? -comp : comp;
     36   const int mv_class = vp9_get_mv_class(mag - 1, &offset);
     37   const int d = offset >> 3;         // int mv data
     38   const int fr = (offset >> 1) & 3;  // fractional mv data
     39   const int hp = offset & 1;         // high precision mv data
     40 
     41   assert(comp != 0);
     42 
     43   // Sign
     44   vpx_write(w, sign, mvcomp->sign);
     45 
     46   // Class
     47   vp9_write_token(w, vp9_mv_class_tree, mvcomp->classes,
     48                   &mv_class_encodings[mv_class]);
     49 
     50   // Integer bits
     51   if (mv_class == MV_CLASS_0) {
     52     vpx_write(w, d, mvcomp->class0[0]);
     53   } else {
     54     int i;
     55     const int n = mv_class + CLASS0_BITS - 1;  // number of bits
     56     for (i = 0; i < n; ++i) vpx_write(w, (d >> i) & 1, mvcomp->bits[i]);
     57   }
     58 
     59   // Fractional bits
     60   vp9_write_token(w, vp9_mv_fp_tree,
     61                   mv_class == MV_CLASS_0 ? mvcomp->class0_fp[d] : mvcomp->fp,
     62                   &mv_fp_encodings[fr]);
     63 
     64   // High precision bit
     65   if (usehp)
     66     vpx_write(w, hp, mv_class == MV_CLASS_0 ? mvcomp->class0_hp : mvcomp->hp);
     67 }
     68 
     69 static void build_nmv_component_cost_table(int *mvcost,
     70                                            const nmv_component *const mvcomp,
     71                                            int usehp) {
     72   int sign_cost[2], class_cost[MV_CLASSES], class0_cost[CLASS0_SIZE];
     73   int bits_cost[MV_OFFSET_BITS][2];
     74   int class0_fp_cost[CLASS0_SIZE][MV_FP_SIZE], fp_cost[MV_FP_SIZE];
     75   int class0_hp_cost[2], hp_cost[2];
     76   int i;
     77   int c, o;
     78 
     79   sign_cost[0] = vp9_cost_zero(mvcomp->sign);
     80   sign_cost[1] = vp9_cost_one(mvcomp->sign);
     81   vp9_cost_tokens(class_cost, mvcomp->classes, vp9_mv_class_tree);
     82   vp9_cost_tokens(class0_cost, mvcomp->class0, vp9_mv_class0_tree);
     83   for (i = 0; i < MV_OFFSET_BITS; ++i) {
     84     bits_cost[i][0] = vp9_cost_zero(mvcomp->bits[i]);
     85     bits_cost[i][1] = vp9_cost_one(mvcomp->bits[i]);
     86   }
     87 
     88   for (i = 0; i < CLASS0_SIZE; ++i)
     89     vp9_cost_tokens(class0_fp_cost[i], mvcomp->class0_fp[i], vp9_mv_fp_tree);
     90   vp9_cost_tokens(fp_cost, mvcomp->fp, vp9_mv_fp_tree);
     91 
     92   // Always build the hp costs to avoid an uninitialized warning from gcc
     93   class0_hp_cost[0] = vp9_cost_zero(mvcomp->class0_hp);
     94   class0_hp_cost[1] = vp9_cost_one(mvcomp->class0_hp);
     95   hp_cost[0] = vp9_cost_zero(mvcomp->hp);
     96   hp_cost[1] = vp9_cost_one(mvcomp->hp);
     97 
     98   mvcost[0] = 0;
     99   // MV_CLASS_0
    100   for (o = 0; o < (CLASS0_SIZE << 3); ++o) {
    101     int d, e, f;
    102     int cost = class_cost[MV_CLASS_0];
    103     int v = o + 1;
    104     d = (o >> 3);     /* int mv data */
    105     f = (o >> 1) & 3; /* fractional pel mv data */
    106     cost += class0_cost[d];
    107     cost += class0_fp_cost[d][f];
    108     if (usehp) {
    109       e = (o & 1); /* high precision mv data */
    110       cost += class0_hp_cost[e];
    111     }
    112     mvcost[v] = cost + sign_cost[0];
    113     mvcost[-v] = cost + sign_cost[1];
    114   }
    115   for (c = MV_CLASS_1; c < MV_CLASSES; ++c) {
    116     int d;
    117     for (d = 0; d < (1 << c); ++d) {
    118       int f;
    119       int whole_cost = class_cost[c];
    120       int b = c + CLASS0_BITS - 1; /* number of bits */
    121       for (i = 0; i < b; ++i) whole_cost += bits_cost[i][((d >> i) & 1)];
    122       for (f = 0; f < 4; ++f) {
    123         int cost = whole_cost + fp_cost[f];
    124         int v = (CLASS0_SIZE << (c + 2)) + d * 8 + f * 2 /* + e */ + 1;
    125         if (usehp) {
    126           mvcost[v] = cost + hp_cost[0] + sign_cost[0];
    127           mvcost[-v] = cost + hp_cost[0] + sign_cost[1];
    128           if (v + 1 > MV_MAX) break;
    129           mvcost[v + 1] = cost + hp_cost[1] + sign_cost[0];
    130           mvcost[-v - 1] = cost + hp_cost[1] + sign_cost[1];
    131         } else {
    132           mvcost[v] = cost + sign_cost[0];
    133           mvcost[-v] = cost + sign_cost[1];
    134           if (v + 1 > MV_MAX) break;
    135           mvcost[v + 1] = cost + sign_cost[0];
    136           mvcost[-v - 1] = cost + sign_cost[1];
    137         }
    138       }
    139     }
    140   }
    141 }
    142 
    143 static int update_mv(vpx_writer *w, const unsigned int ct[2], vpx_prob *cur_p,
    144                      vpx_prob upd_p) {
    145   const vpx_prob new_p = get_binary_prob(ct[0], ct[1]) | 1;
    146   const int update = cost_branch256(ct, *cur_p) + vp9_cost_zero(upd_p) >
    147                      cost_branch256(ct, new_p) + vp9_cost_one(upd_p) +
    148                          (7 << VP9_PROB_COST_SHIFT);
    149   vpx_write(w, update, upd_p);
    150   if (update) {
    151     *cur_p = new_p;
    152     vpx_write_literal(w, new_p >> 1, 7);
    153   }
    154   return update;
    155 }
    156 
    157 static void write_mv_update(const vpx_tree_index *tree,
    158                             vpx_prob probs[/*n - 1*/],
    159                             const unsigned int counts[/*n - 1*/], int n,
    160                             vpx_writer *w) {
    161   int i;
    162   unsigned int branch_ct[32][2];
    163 
    164   // Assuming max number of probabilities <= 32
    165   assert(n <= 32);
    166 
    167   vp9_tree_probs_from_distribution(tree, branch_ct, counts);
    168   for (i = 0; i < n - 1; ++i)
    169     update_mv(w, branch_ct[i], &probs[i], MV_UPDATE_PROB);
    170 }
    171 
    172 void vp9_write_nmv_probs(VP9_COMMON *cm, int usehp, vpx_writer *w,
    173                          nmv_context_counts *const counts) {
    174   int i, j;
    175   nmv_context *const mvc = &cm->fc->nmvc;
    176 
    177   write_mv_update(vp9_mv_joint_tree, mvc->joints, counts->joints, MV_JOINTS, w);
    178 
    179   for (i = 0; i < 2; ++i) {
    180     nmv_component *comp = &mvc->comps[i];
    181     nmv_component_counts *comp_counts = &counts->comps[i];
    182 
    183     update_mv(w, comp_counts->sign, &comp->sign, MV_UPDATE_PROB);
    184     write_mv_update(vp9_mv_class_tree, comp->classes, comp_counts->classes,
    185                     MV_CLASSES, w);
    186     write_mv_update(vp9_mv_class0_tree, comp->class0, comp_counts->class0,
    187                     CLASS0_SIZE, w);
    188     for (j = 0; j < MV_OFFSET_BITS; ++j)
    189       update_mv(w, comp_counts->bits[j], &comp->bits[j], MV_UPDATE_PROB);
    190   }
    191 
    192   for (i = 0; i < 2; ++i) {
    193     for (j = 0; j < CLASS0_SIZE; ++j)
    194       write_mv_update(vp9_mv_fp_tree, mvc->comps[i].class0_fp[j],
    195                       counts->comps[i].class0_fp[j], MV_FP_SIZE, w);
    196 
    197     write_mv_update(vp9_mv_fp_tree, mvc->comps[i].fp, counts->comps[i].fp,
    198                     MV_FP_SIZE, w);
    199   }
    200 
    201   if (usehp) {
    202     for (i = 0; i < 2; ++i) {
    203       update_mv(w, counts->comps[i].class0_hp, &mvc->comps[i].class0_hp,
    204                 MV_UPDATE_PROB);
    205       update_mv(w, counts->comps[i].hp, &mvc->comps[i].hp, MV_UPDATE_PROB);
    206     }
    207   }
    208 }
    209 
    210 void vp9_encode_mv(VP9_COMP *cpi, vpx_writer *w, const MV *mv, const MV *ref,
    211                    const nmv_context *mvctx, int usehp,
    212                    unsigned int *const max_mv_magnitude) {
    213   const MV diff = { mv->row - ref->row, mv->col - ref->col };
    214   const MV_JOINT_TYPE j = vp9_get_mv_joint(&diff);
    215   usehp = usehp && use_mv_hp(ref);
    216 
    217   vp9_write_token(w, vp9_mv_joint_tree, mvctx->joints, &mv_joint_encodings[j]);
    218   if (mv_joint_vertical(j))
    219     encode_mv_component(w, diff.row, &mvctx->comps[0], usehp);
    220 
    221   if (mv_joint_horizontal(j))
    222     encode_mv_component(w, diff.col, &mvctx->comps[1], usehp);
    223 
    224   // If auto_mv_step_size is enabled then keep track of the largest
    225   // motion vector component used.
    226   if (cpi->sf.mv.auto_mv_step_size) {
    227     const unsigned int maxv = VPXMAX(abs(mv->row), abs(mv->col)) >> 3;
    228     *max_mv_magnitude = VPXMAX(maxv, *max_mv_magnitude);
    229   }
    230 }
    231 
    232 void vp9_build_nmv_cost_table(int *mvjoint, int *mvcost[2],
    233                               const nmv_context *ctx, int usehp) {
    234   vp9_cost_tokens(mvjoint, ctx->joints, vp9_mv_joint_tree);
    235   build_nmv_component_cost_table(mvcost[0], &ctx->comps[0], usehp);
    236   build_nmv_component_cost_table(mvcost[1], &ctx->comps[1], usehp);
    237 }
    238 
    239 static void inc_mvs(const MODE_INFO *mi, const MB_MODE_INFO_EXT *mbmi_ext,
    240                     const int_mv mvs[2], nmv_context_counts *counts) {
    241   int i;
    242 
    243   for (i = 0; i < 1 + has_second_ref(mi); ++i) {
    244     const MV *ref = &mbmi_ext->ref_mvs[mi->ref_frame[i]][0].as_mv;
    245     const MV diff = { mvs[i].as_mv.row - ref->row,
    246                       mvs[i].as_mv.col - ref->col };
    247     vp9_inc_mv(&diff, counts);
    248   }
    249 }
    250 
    251 void vp9_update_mv_count(ThreadData *td) {
    252   const MACROBLOCKD *xd = &td->mb.e_mbd;
    253   const MODE_INFO *mi = xd->mi[0];
    254   const MB_MODE_INFO_EXT *mbmi_ext = td->mb.mbmi_ext;
    255 
    256   if (mi->sb_type < BLOCK_8X8) {
    257     const int num_4x4_w = num_4x4_blocks_wide_lookup[mi->sb_type];
    258     const int num_4x4_h = num_4x4_blocks_high_lookup[mi->sb_type];
    259     int idx, idy;
    260 
    261     for (idy = 0; idy < 2; idy += num_4x4_h) {
    262       for (idx = 0; idx < 2; idx += num_4x4_w) {
    263         const int i = idy * 2 + idx;
    264         if (mi->bmi[i].as_mode == NEWMV)
    265           inc_mvs(mi, mbmi_ext, mi->bmi[i].as_mv, &td->counts->mv);
    266       }
    267     }
    268   } else {
    269     if (mi->mode == NEWMV) inc_mvs(mi, mbmi_ext, mi->mv, &td->counts->mv);
    270   }
    271 }
    272