1 /* 2 * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 // 12 // Implements core class for intelligibility enhancer. 13 // 14 // Details of the model and algorithm can be found in the original paper: 15 // http://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=6882788 16 // 17 18 #include "webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h" 19 20 #include <math.h> 21 #include <stdlib.h> 22 #include <algorithm> 23 #include <numeric> 24 25 #include "webrtc/base/checks.h" 26 #include "webrtc/common_audio/include/audio_util.h" 27 #include "webrtc/common_audio/window_generator.h" 28 29 namespace webrtc { 30 31 namespace { 32 33 const size_t kErbResolution = 2; 34 const int kWindowSizeMs = 2; 35 const int kChunkSizeMs = 10; // Size provided by APM. 36 const float kClipFreq = 200.0f; 37 const float kConfigRho = 0.02f; // Default production and interpretation SNR. 38 const float kKbdAlpha = 1.5f; 39 const float kLambdaBot = -1.0f; // Extreme values in bisection 40 const float kLambdaTop = -10e-18f; // search for lamda. 41 42 } // namespace 43 44 using std::complex; 45 using std::max; 46 using std::min; 47 using VarianceType = intelligibility::VarianceArray::StepType; 48 49 IntelligibilityEnhancer::TransformCallback::TransformCallback( 50 IntelligibilityEnhancer* parent, 51 IntelligibilityEnhancer::AudioSource source) 52 : parent_(parent), source_(source) { 53 } 54 55 void IntelligibilityEnhancer::TransformCallback::ProcessAudioBlock( 56 const complex<float>* const* in_block, 57 size_t in_channels, 58 size_t frames, 59 size_t /* out_channels */, 60 complex<float>* const* out_block) { 61 RTC_DCHECK_EQ(parent_->freqs_, frames); 62 for (size_t i = 0; i < in_channels; ++i) { 63 parent_->DispatchAudio(source_, in_block[i], out_block[i]); 64 } 65 } 66 67 IntelligibilityEnhancer::IntelligibilityEnhancer() 68 : IntelligibilityEnhancer(IntelligibilityEnhancer::Config()) { 69 } 70 71 IntelligibilityEnhancer::IntelligibilityEnhancer(const Config& config) 72 : freqs_(RealFourier::ComplexLength( 73 RealFourier::FftOrder(config.sample_rate_hz * kWindowSizeMs / 1000))), 74 window_size_(static_cast<size_t>(1 << RealFourier::FftOrder(freqs_))), 75 chunk_length_( 76 static_cast<size_t>(config.sample_rate_hz * kChunkSizeMs / 1000)), 77 bank_size_(GetBankSize(config.sample_rate_hz, kErbResolution)), 78 sample_rate_hz_(config.sample_rate_hz), 79 erb_resolution_(kErbResolution), 80 num_capture_channels_(config.num_capture_channels), 81 num_render_channels_(config.num_render_channels), 82 analysis_rate_(config.analysis_rate), 83 active_(true), 84 clear_variance_(freqs_, 85 config.var_type, 86 config.var_window_size, 87 config.var_decay_rate), 88 noise_variance_(freqs_, 89 config.var_type, 90 config.var_window_size, 91 config.var_decay_rate), 92 filtered_clear_var_(new float[bank_size_]), 93 filtered_noise_var_(new float[bank_size_]), 94 filter_bank_(bank_size_), 95 center_freqs_(new float[bank_size_]), 96 rho_(new float[bank_size_]), 97 gains_eq_(new float[bank_size_]), 98 gain_applier_(freqs_, config.gain_change_limit), 99 temp_render_out_buffer_(chunk_length_, num_render_channels_), 100 temp_capture_out_buffer_(chunk_length_, num_capture_channels_), 101 kbd_window_(new float[window_size_]), 102 render_callback_(this, AudioSource::kRenderStream), 103 capture_callback_(this, AudioSource::kCaptureStream), 104 block_count_(0), 105 analysis_step_(0) { 106 RTC_DCHECK_LE(config.rho, 1.0f); 107 108 CreateErbBank(); 109 110 // Assumes all rho equal. 111 for (size_t i = 0; i < bank_size_; ++i) { 112 rho_[i] = config.rho * config.rho; 113 } 114 115 float freqs_khz = kClipFreq / 1000.0f; 116 size_t erb_index = static_cast<size_t>(ceilf( 117 11.17f * logf((freqs_khz + 0.312f) / (freqs_khz + 14.6575f)) + 43.0f)); 118 start_freq_ = std::max(static_cast<size_t>(1), erb_index * erb_resolution_); 119 120 WindowGenerator::KaiserBesselDerived(kKbdAlpha, window_size_, 121 kbd_window_.get()); 122 render_mangler_.reset(new LappedTransform( 123 num_render_channels_, num_render_channels_, chunk_length_, 124 kbd_window_.get(), window_size_, window_size_ / 2, &render_callback_)); 125 capture_mangler_.reset(new LappedTransform( 126 num_capture_channels_, num_capture_channels_, chunk_length_, 127 kbd_window_.get(), window_size_, window_size_ / 2, &capture_callback_)); 128 } 129 130 void IntelligibilityEnhancer::ProcessRenderAudio(float* const* audio, 131 int sample_rate_hz, 132 size_t num_channels) { 133 RTC_CHECK_EQ(sample_rate_hz_, sample_rate_hz); 134 RTC_CHECK_EQ(num_render_channels_, num_channels); 135 136 if (active_) { 137 render_mangler_->ProcessChunk(audio, temp_render_out_buffer_.channels()); 138 } 139 140 if (active_) { 141 for (size_t i = 0; i < num_render_channels_; ++i) { 142 memcpy(audio[i], temp_render_out_buffer_.channels()[i], 143 chunk_length_ * sizeof(**audio)); 144 } 145 } 146 } 147 148 void IntelligibilityEnhancer::AnalyzeCaptureAudio(float* const* audio, 149 int sample_rate_hz, 150 size_t num_channels) { 151 RTC_CHECK_EQ(sample_rate_hz_, sample_rate_hz); 152 RTC_CHECK_EQ(num_capture_channels_, num_channels); 153 154 capture_mangler_->ProcessChunk(audio, temp_capture_out_buffer_.channels()); 155 } 156 157 void IntelligibilityEnhancer::DispatchAudio( 158 IntelligibilityEnhancer::AudioSource source, 159 const complex<float>* in_block, 160 complex<float>* out_block) { 161 switch (source) { 162 case kRenderStream: 163 ProcessClearBlock(in_block, out_block); 164 break; 165 case kCaptureStream: 166 ProcessNoiseBlock(in_block, out_block); 167 break; 168 } 169 } 170 171 void IntelligibilityEnhancer::ProcessClearBlock(const complex<float>* in_block, 172 complex<float>* out_block) { 173 if (block_count_ < 2) { 174 memset(out_block, 0, freqs_ * sizeof(*out_block)); 175 ++block_count_; 176 return; 177 } 178 179 // TODO(ekm): Use VAD to |Step| and |AnalyzeClearBlock| only if necessary. 180 if (true) { 181 clear_variance_.Step(in_block, false); 182 if (block_count_ % analysis_rate_ == analysis_rate_ - 1) { 183 const float power_target = std::accumulate( 184 clear_variance_.variance(), clear_variance_.variance() + freqs_, 0.f); 185 AnalyzeClearBlock(power_target); 186 ++analysis_step_; 187 } 188 ++block_count_; 189 } 190 191 if (active_) { 192 gain_applier_.Apply(in_block, out_block); 193 } 194 } 195 196 void IntelligibilityEnhancer::AnalyzeClearBlock(float power_target) { 197 FilterVariance(clear_variance_.variance(), filtered_clear_var_.get()); 198 FilterVariance(noise_variance_.variance(), filtered_noise_var_.get()); 199 200 SolveForGainsGivenLambda(kLambdaTop, start_freq_, gains_eq_.get()); 201 const float power_top = 202 DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_); 203 SolveForGainsGivenLambda(kLambdaBot, start_freq_, gains_eq_.get()); 204 const float power_bot = 205 DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_); 206 if (power_target >= power_bot && power_target <= power_top) { 207 SolveForLambda(power_target, power_bot, power_top); 208 UpdateErbGains(); 209 } // Else experiencing variance underflow, so do nothing. 210 } 211 212 void IntelligibilityEnhancer::SolveForLambda(float power_target, 213 float power_bot, 214 float power_top) { 215 const float kConvergeThresh = 0.001f; // TODO(ekmeyerson): Find best values 216 const int kMaxIters = 100; // for these, based on experiments. 217 218 const float reciprocal_power_target = 1.f / power_target; 219 float lambda_bot = kLambdaBot; 220 float lambda_top = kLambdaTop; 221 float power_ratio = 2.0f; // Ratio of achieved power to target power. 222 int iters = 0; 223 while (std::fabs(power_ratio - 1.0f) > kConvergeThresh && 224 iters <= kMaxIters) { 225 const float lambda = lambda_bot + (lambda_top - lambda_bot) / 2.0f; 226 SolveForGainsGivenLambda(lambda, start_freq_, gains_eq_.get()); 227 const float power = 228 DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_); 229 if (power < power_target) { 230 lambda_bot = lambda; 231 } else { 232 lambda_top = lambda; 233 } 234 power_ratio = std::fabs(power * reciprocal_power_target); 235 ++iters; 236 } 237 } 238 239 void IntelligibilityEnhancer::UpdateErbGains() { 240 // (ERB gain) = filterbank' * (freq gain) 241 float* gains = gain_applier_.target(); 242 for (size_t i = 0; i < freqs_; ++i) { 243 gains[i] = 0.0f; 244 for (size_t j = 0; j < bank_size_; ++j) { 245 gains[i] = fmaf(filter_bank_[j][i], gains_eq_[j], gains[i]); 246 } 247 } 248 } 249 250 void IntelligibilityEnhancer::ProcessNoiseBlock(const complex<float>* in_block, 251 complex<float>* /*out_block*/) { 252 noise_variance_.Step(in_block); 253 } 254 255 size_t IntelligibilityEnhancer::GetBankSize(int sample_rate, 256 size_t erb_resolution) { 257 float freq_limit = sample_rate / 2000.0f; 258 size_t erb_scale = static_cast<size_t>(ceilf( 259 11.17f * logf((freq_limit + 0.312f) / (freq_limit + 14.6575f)) + 43.0f)); 260 return erb_scale * erb_resolution; 261 } 262 263 void IntelligibilityEnhancer::CreateErbBank() { 264 size_t lf = 1, rf = 4; 265 266 for (size_t i = 0; i < bank_size_; ++i) { 267 float abs_temp = fabsf((i + 1.0f) / static_cast<float>(erb_resolution_)); 268 center_freqs_[i] = 676170.4f / (47.06538f - expf(0.08950404f * abs_temp)); 269 center_freqs_[i] -= 14678.49f; 270 } 271 float last_center_freq = center_freqs_[bank_size_ - 1]; 272 for (size_t i = 0; i < bank_size_; ++i) { 273 center_freqs_[i] *= 0.5f * sample_rate_hz_ / last_center_freq; 274 } 275 276 for (size_t i = 0; i < bank_size_; ++i) { 277 filter_bank_[i].resize(freqs_); 278 } 279 280 for (size_t i = 1; i <= bank_size_; ++i) { 281 size_t lll, ll, rr, rrr; 282 static const size_t kOne = 1; // Avoids repeated static_cast<>s below. 283 lll = static_cast<size_t>(round( 284 center_freqs_[max(kOne, i - lf) - 1] * freqs_ / 285 (0.5f * sample_rate_hz_))); 286 ll = static_cast<size_t>(round( 287 center_freqs_[max(kOne, i) - 1] * freqs_ / (0.5f * sample_rate_hz_))); 288 lll = min(freqs_, max(lll, kOne)) - 1; 289 ll = min(freqs_, max(ll, kOne)) - 1; 290 291 rrr = static_cast<size_t>(round( 292 center_freqs_[min(bank_size_, i + rf) - 1] * freqs_ / 293 (0.5f * sample_rate_hz_))); 294 rr = static_cast<size_t>(round( 295 center_freqs_[min(bank_size_, i + 1) - 1] * freqs_ / 296 (0.5f * sample_rate_hz_))); 297 rrr = min(freqs_, max(rrr, kOne)) - 1; 298 rr = min(freqs_, max(rr, kOne)) - 1; 299 300 float step, element; 301 302 step = 1.0f / (ll - lll); 303 element = 0.0f; 304 for (size_t j = lll; j <= ll; ++j) { 305 filter_bank_[i - 1][j] = element; 306 element += step; 307 } 308 step = 1.0f / (rrr - rr); 309 element = 1.0f; 310 for (size_t j = rr; j <= rrr; ++j) { 311 filter_bank_[i - 1][j] = element; 312 element -= step; 313 } 314 for (size_t j = ll; j <= rr; ++j) { 315 filter_bank_[i - 1][j] = 1.0f; 316 } 317 } 318 319 float sum; 320 for (size_t i = 0; i < freqs_; ++i) { 321 sum = 0.0f; 322 for (size_t j = 0; j < bank_size_; ++j) { 323 sum += filter_bank_[j][i]; 324 } 325 for (size_t j = 0; j < bank_size_; ++j) { 326 filter_bank_[j][i] /= sum; 327 } 328 } 329 } 330 331 void IntelligibilityEnhancer::SolveForGainsGivenLambda(float lambda, 332 size_t start_freq, 333 float* sols) { 334 bool quadratic = (kConfigRho < 1.0f); 335 const float* var_x0 = filtered_clear_var_.get(); 336 const float* var_n0 = filtered_noise_var_.get(); 337 338 for (size_t n = 0; n < start_freq; ++n) { 339 sols[n] = 1.0f; 340 } 341 342 // Analytic solution for optimal gains. See paper for derivation. 343 for (size_t n = start_freq - 1; n < bank_size_; ++n) { 344 float alpha0, beta0, gamma0; 345 gamma0 = 0.5f * rho_[n] * var_x0[n] * var_n0[n] + 346 lambda * var_x0[n] * var_n0[n] * var_n0[n]; 347 beta0 = lambda * var_x0[n] * (2 - rho_[n]) * var_x0[n] * var_n0[n]; 348 if (quadratic) { 349 alpha0 = lambda * var_x0[n] * (1 - rho_[n]) * var_x0[n] * var_x0[n]; 350 sols[n] = 351 (-beta0 - sqrtf(beta0 * beta0 - 4 * alpha0 * gamma0)) / (2 * alpha0); 352 } else { 353 sols[n] = -gamma0 / beta0; 354 } 355 sols[n] = fmax(0, sols[n]); 356 } 357 } 358 359 void IntelligibilityEnhancer::FilterVariance(const float* var, float* result) { 360 RTC_DCHECK_GT(freqs_, 0u); 361 for (size_t i = 0; i < bank_size_; ++i) { 362 result[i] = DotProduct(&filter_bank_[i][0], var, freqs_); 363 } 364 } 365 366 float IntelligibilityEnhancer::DotProduct(const float* a, 367 const float* b, 368 size_t length) { 369 float ret = 0.0f; 370 371 for (size_t i = 0; i < length; ++i) { 372 ret = fmaf(a[i], b[i], ret); 373 } 374 return ret; 375 } 376 377 bool IntelligibilityEnhancer::active() const { 378 return active_; 379 } 380 381 } // namespace webrtc 382