Home | History | Annotate | Download | only in intelligibility
      1 /*
      2  *  Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
      3  *
      4  *  Use of this source code is governed by a BSD-style license
      5  *  that can be found in the LICENSE file in the root of the source
      6  *  tree. An additional intellectual property rights grant can be found
      7  *  in the file PATENTS.  All contributing project authors may
      8  *  be found in the AUTHORS file in the root of the source tree.
      9  */
     10 
     11 //
     12 //  Implements core class for intelligibility enhancer.
     13 //
     14 //  Details of the model and algorithm can be found in the original paper:
     15 //  http://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=6882788
     16 //
     17 
     18 #include "webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h"
     19 
     20 #include <math.h>
     21 #include <stdlib.h>
     22 #include <algorithm>
     23 #include <numeric>
     24 
     25 #include "webrtc/base/checks.h"
     26 #include "webrtc/common_audio/include/audio_util.h"
     27 #include "webrtc/common_audio/window_generator.h"
     28 
     29 namespace webrtc {
     30 
     31 namespace {
     32 
     33 const size_t kErbResolution = 2;
     34 const int kWindowSizeMs = 2;
     35 const int kChunkSizeMs = 10;  // Size provided by APM.
     36 const float kClipFreq = 200.0f;
     37 const float kConfigRho = 0.02f;  // Default production and interpretation SNR.
     38 const float kKbdAlpha = 1.5f;
     39 const float kLambdaBot = -1.0f;      // Extreme values in bisection
     40 const float kLambdaTop = -10e-18f;  // search for lamda.
     41 
     42 }  // namespace
     43 
     44 using std::complex;
     45 using std::max;
     46 using std::min;
     47 using VarianceType = intelligibility::VarianceArray::StepType;
     48 
     49 IntelligibilityEnhancer::TransformCallback::TransformCallback(
     50     IntelligibilityEnhancer* parent,
     51     IntelligibilityEnhancer::AudioSource source)
     52     : parent_(parent), source_(source) {
     53 }
     54 
     55 void IntelligibilityEnhancer::TransformCallback::ProcessAudioBlock(
     56     const complex<float>* const* in_block,
     57     size_t in_channels,
     58     size_t frames,
     59     size_t /* out_channels */,
     60     complex<float>* const* out_block) {
     61   RTC_DCHECK_EQ(parent_->freqs_, frames);
     62   for (size_t i = 0; i < in_channels; ++i) {
     63     parent_->DispatchAudio(source_, in_block[i], out_block[i]);
     64   }
     65 }
     66 
     67 IntelligibilityEnhancer::IntelligibilityEnhancer()
     68     : IntelligibilityEnhancer(IntelligibilityEnhancer::Config()) {
     69 }
     70 
     71 IntelligibilityEnhancer::IntelligibilityEnhancer(const Config& config)
     72     : freqs_(RealFourier::ComplexLength(
     73           RealFourier::FftOrder(config.sample_rate_hz * kWindowSizeMs / 1000))),
     74       window_size_(static_cast<size_t>(1 << RealFourier::FftOrder(freqs_))),
     75       chunk_length_(
     76           static_cast<size_t>(config.sample_rate_hz * kChunkSizeMs / 1000)),
     77       bank_size_(GetBankSize(config.sample_rate_hz, kErbResolution)),
     78       sample_rate_hz_(config.sample_rate_hz),
     79       erb_resolution_(kErbResolution),
     80       num_capture_channels_(config.num_capture_channels),
     81       num_render_channels_(config.num_render_channels),
     82       analysis_rate_(config.analysis_rate),
     83       active_(true),
     84       clear_variance_(freqs_,
     85                       config.var_type,
     86                       config.var_window_size,
     87                       config.var_decay_rate),
     88       noise_variance_(freqs_,
     89                       config.var_type,
     90                       config.var_window_size,
     91                       config.var_decay_rate),
     92       filtered_clear_var_(new float[bank_size_]),
     93       filtered_noise_var_(new float[bank_size_]),
     94       filter_bank_(bank_size_),
     95       center_freqs_(new float[bank_size_]),
     96       rho_(new float[bank_size_]),
     97       gains_eq_(new float[bank_size_]),
     98       gain_applier_(freqs_, config.gain_change_limit),
     99       temp_render_out_buffer_(chunk_length_, num_render_channels_),
    100       temp_capture_out_buffer_(chunk_length_, num_capture_channels_),
    101       kbd_window_(new float[window_size_]),
    102       render_callback_(this, AudioSource::kRenderStream),
    103       capture_callback_(this, AudioSource::kCaptureStream),
    104       block_count_(0),
    105       analysis_step_(0) {
    106   RTC_DCHECK_LE(config.rho, 1.0f);
    107 
    108   CreateErbBank();
    109 
    110   // Assumes all rho equal.
    111   for (size_t i = 0; i < bank_size_; ++i) {
    112     rho_[i] = config.rho * config.rho;
    113   }
    114 
    115   float freqs_khz = kClipFreq / 1000.0f;
    116   size_t erb_index = static_cast<size_t>(ceilf(
    117       11.17f * logf((freqs_khz + 0.312f) / (freqs_khz + 14.6575f)) + 43.0f));
    118   start_freq_ = std::max(static_cast<size_t>(1), erb_index * erb_resolution_);
    119 
    120   WindowGenerator::KaiserBesselDerived(kKbdAlpha, window_size_,
    121                                        kbd_window_.get());
    122   render_mangler_.reset(new LappedTransform(
    123       num_render_channels_, num_render_channels_, chunk_length_,
    124       kbd_window_.get(), window_size_, window_size_ / 2, &render_callback_));
    125   capture_mangler_.reset(new LappedTransform(
    126       num_capture_channels_, num_capture_channels_, chunk_length_,
    127       kbd_window_.get(), window_size_, window_size_ / 2, &capture_callback_));
    128 }
    129 
    130 void IntelligibilityEnhancer::ProcessRenderAudio(float* const* audio,
    131                                                  int sample_rate_hz,
    132                                                  size_t num_channels) {
    133   RTC_CHECK_EQ(sample_rate_hz_, sample_rate_hz);
    134   RTC_CHECK_EQ(num_render_channels_, num_channels);
    135 
    136   if (active_) {
    137     render_mangler_->ProcessChunk(audio, temp_render_out_buffer_.channels());
    138   }
    139 
    140   if (active_) {
    141     for (size_t i = 0; i < num_render_channels_; ++i) {
    142       memcpy(audio[i], temp_render_out_buffer_.channels()[i],
    143              chunk_length_ * sizeof(**audio));
    144     }
    145   }
    146 }
    147 
    148 void IntelligibilityEnhancer::AnalyzeCaptureAudio(float* const* audio,
    149                                                   int sample_rate_hz,
    150                                                   size_t num_channels) {
    151   RTC_CHECK_EQ(sample_rate_hz_, sample_rate_hz);
    152   RTC_CHECK_EQ(num_capture_channels_, num_channels);
    153 
    154   capture_mangler_->ProcessChunk(audio, temp_capture_out_buffer_.channels());
    155 }
    156 
    157 void IntelligibilityEnhancer::DispatchAudio(
    158     IntelligibilityEnhancer::AudioSource source,
    159     const complex<float>* in_block,
    160     complex<float>* out_block) {
    161   switch (source) {
    162     case kRenderStream:
    163       ProcessClearBlock(in_block, out_block);
    164       break;
    165     case kCaptureStream:
    166       ProcessNoiseBlock(in_block, out_block);
    167       break;
    168   }
    169 }
    170 
    171 void IntelligibilityEnhancer::ProcessClearBlock(const complex<float>* in_block,
    172                                                 complex<float>* out_block) {
    173   if (block_count_ < 2) {
    174     memset(out_block, 0, freqs_ * sizeof(*out_block));
    175     ++block_count_;
    176     return;
    177   }
    178 
    179   // TODO(ekm): Use VAD to |Step| and |AnalyzeClearBlock| only if necessary.
    180   if (true) {
    181     clear_variance_.Step(in_block, false);
    182     if (block_count_ % analysis_rate_ == analysis_rate_ - 1) {
    183       const float power_target = std::accumulate(
    184           clear_variance_.variance(), clear_variance_.variance() + freqs_, 0.f);
    185       AnalyzeClearBlock(power_target);
    186       ++analysis_step_;
    187     }
    188     ++block_count_;
    189   }
    190 
    191   if (active_) {
    192     gain_applier_.Apply(in_block, out_block);
    193   }
    194 }
    195 
    196 void IntelligibilityEnhancer::AnalyzeClearBlock(float power_target) {
    197   FilterVariance(clear_variance_.variance(), filtered_clear_var_.get());
    198   FilterVariance(noise_variance_.variance(), filtered_noise_var_.get());
    199 
    200   SolveForGainsGivenLambda(kLambdaTop, start_freq_, gains_eq_.get());
    201   const float power_top =
    202       DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_);
    203   SolveForGainsGivenLambda(kLambdaBot, start_freq_, gains_eq_.get());
    204   const float power_bot =
    205       DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_);
    206   if (power_target >= power_bot && power_target <= power_top) {
    207     SolveForLambda(power_target, power_bot, power_top);
    208     UpdateErbGains();
    209   }  // Else experiencing variance underflow, so do nothing.
    210 }
    211 
    212 void IntelligibilityEnhancer::SolveForLambda(float power_target,
    213                                              float power_bot,
    214                                              float power_top) {
    215   const float kConvergeThresh = 0.001f;  // TODO(ekmeyerson): Find best values
    216   const int kMaxIters = 100;             // for these, based on experiments.
    217 
    218   const float reciprocal_power_target = 1.f / power_target;
    219   float lambda_bot = kLambdaBot;
    220   float lambda_top = kLambdaTop;
    221   float power_ratio = 2.0f;  // Ratio of achieved power to target power.
    222   int iters = 0;
    223   while (std::fabs(power_ratio - 1.0f) > kConvergeThresh &&
    224          iters <= kMaxIters) {
    225     const float lambda = lambda_bot + (lambda_top - lambda_bot) / 2.0f;
    226     SolveForGainsGivenLambda(lambda, start_freq_, gains_eq_.get());
    227     const float power =
    228         DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_);
    229     if (power < power_target) {
    230       lambda_bot = lambda;
    231     } else {
    232       lambda_top = lambda;
    233     }
    234     power_ratio = std::fabs(power * reciprocal_power_target);
    235     ++iters;
    236   }
    237 }
    238 
    239 void IntelligibilityEnhancer::UpdateErbGains() {
    240   // (ERB gain) = filterbank' * (freq gain)
    241   float* gains = gain_applier_.target();
    242   for (size_t i = 0; i < freqs_; ++i) {
    243     gains[i] = 0.0f;
    244     for (size_t j = 0; j < bank_size_; ++j) {
    245       gains[i] = fmaf(filter_bank_[j][i], gains_eq_[j], gains[i]);
    246     }
    247   }
    248 }
    249 
    250 void IntelligibilityEnhancer::ProcessNoiseBlock(const complex<float>* in_block,
    251                                                 complex<float>* /*out_block*/) {
    252   noise_variance_.Step(in_block);
    253 }
    254 
    255 size_t IntelligibilityEnhancer::GetBankSize(int sample_rate,
    256                                             size_t erb_resolution) {
    257   float freq_limit = sample_rate / 2000.0f;
    258   size_t erb_scale = static_cast<size_t>(ceilf(
    259       11.17f * logf((freq_limit + 0.312f) / (freq_limit + 14.6575f)) + 43.0f));
    260   return erb_scale * erb_resolution;
    261 }
    262 
    263 void IntelligibilityEnhancer::CreateErbBank() {
    264   size_t lf = 1, rf = 4;
    265 
    266   for (size_t i = 0; i < bank_size_; ++i) {
    267     float abs_temp = fabsf((i + 1.0f) / static_cast<float>(erb_resolution_));
    268     center_freqs_[i] = 676170.4f / (47.06538f - expf(0.08950404f * abs_temp));
    269     center_freqs_[i] -= 14678.49f;
    270   }
    271   float last_center_freq = center_freqs_[bank_size_ - 1];
    272   for (size_t i = 0; i < bank_size_; ++i) {
    273     center_freqs_[i] *= 0.5f * sample_rate_hz_ / last_center_freq;
    274   }
    275 
    276   for (size_t i = 0; i < bank_size_; ++i) {
    277     filter_bank_[i].resize(freqs_);
    278   }
    279 
    280   for (size_t i = 1; i <= bank_size_; ++i) {
    281     size_t lll, ll, rr, rrr;
    282     static const size_t kOne = 1;  // Avoids repeated static_cast<>s below.
    283     lll = static_cast<size_t>(round(
    284         center_freqs_[max(kOne, i - lf) - 1] * freqs_ /
    285             (0.5f * sample_rate_hz_)));
    286     ll = static_cast<size_t>(round(
    287         center_freqs_[max(kOne, i) - 1] * freqs_ / (0.5f * sample_rate_hz_)));
    288     lll = min(freqs_, max(lll, kOne)) - 1;
    289     ll = min(freqs_, max(ll, kOne)) - 1;
    290 
    291     rrr = static_cast<size_t>(round(
    292         center_freqs_[min(bank_size_, i + rf) - 1] * freqs_ /
    293             (0.5f * sample_rate_hz_)));
    294     rr = static_cast<size_t>(round(
    295         center_freqs_[min(bank_size_, i + 1) - 1] * freqs_ /
    296             (0.5f * sample_rate_hz_)));
    297     rrr = min(freqs_, max(rrr, kOne)) - 1;
    298     rr = min(freqs_, max(rr, kOne)) - 1;
    299 
    300     float step, element;
    301 
    302     step = 1.0f / (ll - lll);
    303     element = 0.0f;
    304     for (size_t j = lll; j <= ll; ++j) {
    305       filter_bank_[i - 1][j] = element;
    306       element += step;
    307     }
    308     step = 1.0f / (rrr - rr);
    309     element = 1.0f;
    310     for (size_t j = rr; j <= rrr; ++j) {
    311       filter_bank_[i - 1][j] = element;
    312       element -= step;
    313     }
    314     for (size_t j = ll; j <= rr; ++j) {
    315       filter_bank_[i - 1][j] = 1.0f;
    316     }
    317   }
    318 
    319   float sum;
    320   for (size_t i = 0; i < freqs_; ++i) {
    321     sum = 0.0f;
    322     for (size_t j = 0; j < bank_size_; ++j) {
    323       sum += filter_bank_[j][i];
    324     }
    325     for (size_t j = 0; j < bank_size_; ++j) {
    326       filter_bank_[j][i] /= sum;
    327     }
    328   }
    329 }
    330 
    331 void IntelligibilityEnhancer::SolveForGainsGivenLambda(float lambda,
    332                                                        size_t start_freq,
    333                                                        float* sols) {
    334   bool quadratic = (kConfigRho < 1.0f);
    335   const float* var_x0 = filtered_clear_var_.get();
    336   const float* var_n0 = filtered_noise_var_.get();
    337 
    338   for (size_t n = 0; n < start_freq; ++n) {
    339     sols[n] = 1.0f;
    340   }
    341 
    342   // Analytic solution for optimal gains. See paper for derivation.
    343   for (size_t n = start_freq - 1; n < bank_size_; ++n) {
    344     float alpha0, beta0, gamma0;
    345     gamma0 = 0.5f * rho_[n] * var_x0[n] * var_n0[n] +
    346              lambda * var_x0[n] * var_n0[n] * var_n0[n];
    347     beta0 = lambda * var_x0[n] * (2 - rho_[n]) * var_x0[n] * var_n0[n];
    348     if (quadratic) {
    349       alpha0 = lambda * var_x0[n] * (1 - rho_[n]) * var_x0[n] * var_x0[n];
    350       sols[n] =
    351           (-beta0 - sqrtf(beta0 * beta0 - 4 * alpha0 * gamma0)) / (2 * alpha0);
    352     } else {
    353       sols[n] = -gamma0 / beta0;
    354     }
    355     sols[n] = fmax(0, sols[n]);
    356   }
    357 }
    358 
    359 void IntelligibilityEnhancer::FilterVariance(const float* var, float* result) {
    360   RTC_DCHECK_GT(freqs_, 0u);
    361   for (size_t i = 0; i < bank_size_; ++i) {
    362     result[i] = DotProduct(&filter_bank_[i][0], var, freqs_);
    363   }
    364 }
    365 
    366 float IntelligibilityEnhancer::DotProduct(const float* a,
    367                                           const float* b,
    368                                           size_t length) {
    369   float ret = 0.0f;
    370 
    371   for (size_t i = 0; i < length; ++i) {
    372     ret = fmaf(a[i], b[i], ret);
    373   }
    374   return ret;
    375 }
    376 
    377 bool IntelligibilityEnhancer::active() const {
    378   return active_;
    379 }
    380 
    381 }  // namespace webrtc
    382