1 /* 2 * Copyright (c) 2013 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #include "webrtc/modules/audio_processing/transient/transient_suppressor.h" 12 13 #include <math.h> 14 #include <string.h> 15 #include <cmath> 16 #include <complex> 17 #include <deque> 18 #include <set> 19 20 #include "webrtc/base/scoped_ptr.h" 21 #include "webrtc/common_audio/fft4g.h" 22 #include "webrtc/common_audio/include/audio_util.h" 23 #include "webrtc/common_audio/signal_processing/include/signal_processing_library.h" 24 #include "webrtc/modules/audio_processing/transient/common.h" 25 #include "webrtc/modules/audio_processing/transient/transient_detector.h" 26 #include "webrtc/modules/audio_processing/ns/windows_private.h" 27 #include "webrtc/system_wrappers/include/logging.h" 28 #include "webrtc/typedefs.h" 29 30 namespace webrtc { 31 32 static const float kMeanIIRCoefficient = 0.5f; 33 static const float kVoiceThreshold = 0.02f; 34 35 // TODO(aluebs): Check if these values work also for 48kHz. 36 static const size_t kMinVoiceBin = 3; 37 static const size_t kMaxVoiceBin = 60; 38 39 namespace { 40 41 float ComplexMagnitude(float a, float b) { 42 return std::abs(a) + std::abs(b); 43 } 44 45 } // namespace 46 47 TransientSuppressor::TransientSuppressor() 48 : data_length_(0), 49 detection_length_(0), 50 analysis_length_(0), 51 buffer_delay_(0), 52 complex_analysis_length_(0), 53 num_channels_(0), 54 window_(NULL), 55 detector_smoothed_(0.f), 56 keypress_counter_(0), 57 chunks_since_keypress_(0), 58 detection_enabled_(false), 59 suppression_enabled_(false), 60 use_hard_restoration_(false), 61 chunks_since_voice_change_(0), 62 seed_(182), 63 using_reference_(false) { 64 } 65 66 TransientSuppressor::~TransientSuppressor() {} 67 68 int TransientSuppressor::Initialize(int sample_rate_hz, 69 int detection_rate_hz, 70 int num_channels) { 71 switch (sample_rate_hz) { 72 case ts::kSampleRate8kHz: 73 analysis_length_ = 128u; 74 window_ = kBlocks80w128; 75 break; 76 case ts::kSampleRate16kHz: 77 analysis_length_ = 256u; 78 window_ = kBlocks160w256; 79 break; 80 case ts::kSampleRate32kHz: 81 analysis_length_ = 512u; 82 window_ = kBlocks320w512; 83 break; 84 case ts::kSampleRate48kHz: 85 analysis_length_ = 1024u; 86 window_ = kBlocks480w1024; 87 break; 88 default: 89 return -1; 90 } 91 if (detection_rate_hz != ts::kSampleRate8kHz && 92 detection_rate_hz != ts::kSampleRate16kHz && 93 detection_rate_hz != ts::kSampleRate32kHz && 94 detection_rate_hz != ts::kSampleRate48kHz) { 95 return -1; 96 } 97 if (num_channels <= 0) { 98 return -1; 99 } 100 101 detector_.reset(new TransientDetector(detection_rate_hz)); 102 data_length_ = sample_rate_hz * ts::kChunkSizeMs / 1000; 103 if (data_length_ > analysis_length_) { 104 assert(false); 105 return -1; 106 } 107 buffer_delay_ = analysis_length_ - data_length_; 108 109 complex_analysis_length_ = analysis_length_ / 2 + 1; 110 assert(complex_analysis_length_ >= kMaxVoiceBin); 111 num_channels_ = num_channels; 112 in_buffer_.reset(new float[analysis_length_ * num_channels_]); 113 memset(in_buffer_.get(), 114 0, 115 analysis_length_ * num_channels_ * sizeof(in_buffer_[0])); 116 detection_length_ = detection_rate_hz * ts::kChunkSizeMs / 1000; 117 detection_buffer_.reset(new float[detection_length_]); 118 memset(detection_buffer_.get(), 119 0, 120 detection_length_ * sizeof(detection_buffer_[0])); 121 out_buffer_.reset(new float[analysis_length_ * num_channels_]); 122 memset(out_buffer_.get(), 123 0, 124 analysis_length_ * num_channels_ * sizeof(out_buffer_[0])); 125 // ip[0] must be zero to trigger initialization using rdft(). 126 size_t ip_length = 2 + sqrtf(analysis_length_); 127 ip_.reset(new size_t[ip_length]()); 128 memset(ip_.get(), 0, ip_length * sizeof(ip_[0])); 129 wfft_.reset(new float[complex_analysis_length_ - 1]); 130 memset(wfft_.get(), 0, (complex_analysis_length_ - 1) * sizeof(wfft_[0])); 131 spectral_mean_.reset(new float[complex_analysis_length_ * num_channels_]); 132 memset(spectral_mean_.get(), 133 0, 134 complex_analysis_length_ * num_channels_ * sizeof(spectral_mean_[0])); 135 fft_buffer_.reset(new float[analysis_length_ + 2]); 136 memset(fft_buffer_.get(), 0, (analysis_length_ + 2) * sizeof(fft_buffer_[0])); 137 magnitudes_.reset(new float[complex_analysis_length_]); 138 memset(magnitudes_.get(), 139 0, 140 complex_analysis_length_ * sizeof(magnitudes_[0])); 141 mean_factor_.reset(new float[complex_analysis_length_]); 142 143 static const float kFactorHeight = 10.f; 144 static const float kLowSlope = 1.f; 145 static const float kHighSlope = 0.3f; 146 for (size_t i = 0; i < complex_analysis_length_; ++i) { 147 mean_factor_[i] = 148 kFactorHeight / 149 (1.f + exp(kLowSlope * static_cast<int>(i - kMinVoiceBin))) + 150 kFactorHeight / 151 (1.f + exp(kHighSlope * static_cast<int>(kMaxVoiceBin - i))); 152 } 153 detector_smoothed_ = 0.f; 154 keypress_counter_ = 0; 155 chunks_since_keypress_ = 0; 156 detection_enabled_ = false; 157 suppression_enabled_ = false; 158 use_hard_restoration_ = false; 159 chunks_since_voice_change_ = 0; 160 seed_ = 182; 161 using_reference_ = false; 162 return 0; 163 } 164 165 int TransientSuppressor::Suppress(float* data, 166 size_t data_length, 167 int num_channels, 168 const float* detection_data, 169 size_t detection_length, 170 const float* reference_data, 171 size_t reference_length, 172 float voice_probability, 173 bool key_pressed) { 174 if (!data || data_length != data_length_ || num_channels != num_channels_ || 175 detection_length != detection_length_ || voice_probability < 0 || 176 voice_probability > 1) { 177 return -1; 178 } 179 180 UpdateKeypress(key_pressed); 181 UpdateBuffers(data); 182 183 int result = 0; 184 if (detection_enabled_) { 185 UpdateRestoration(voice_probability); 186 187 if (!detection_data) { 188 // Use the input data of the first channel if special detection data is 189 // not supplied. 190 detection_data = &in_buffer_[buffer_delay_]; 191 } 192 193 float detector_result = detector_->Detect( 194 detection_data, detection_length, reference_data, reference_length); 195 if (detector_result < 0) { 196 return -1; 197 } 198 199 using_reference_ = detector_->using_reference(); 200 201 // |detector_smoothed_| follows the |detector_result| when this last one is 202 // increasing, but has an exponential decaying tail to be able to suppress 203 // the ringing of keyclicks. 204 float smooth_factor = using_reference_ ? 0.6 : 0.1; 205 detector_smoothed_ = detector_result >= detector_smoothed_ 206 ? detector_result 207 : smooth_factor * detector_smoothed_ + 208 (1 - smooth_factor) * detector_result; 209 210 for (int i = 0; i < num_channels_; ++i) { 211 Suppress(&in_buffer_[i * analysis_length_], 212 &spectral_mean_[i * complex_analysis_length_], 213 &out_buffer_[i * analysis_length_]); 214 } 215 } 216 217 // If the suppression isn't enabled, we use the in buffer to delay the signal 218 // appropriately. This also gives time for the out buffer to be refreshed with 219 // new data between detection and suppression getting enabled. 220 for (int i = 0; i < num_channels_; ++i) { 221 memcpy(&data[i * data_length_], 222 suppression_enabled_ ? &out_buffer_[i * analysis_length_] 223 : &in_buffer_[i * analysis_length_], 224 data_length_ * sizeof(*data)); 225 } 226 return result; 227 } 228 229 // This should only be called when detection is enabled. UpdateBuffers() must 230 // have been called. At return, |out_buffer_| will be filled with the 231 // processed output. 232 void TransientSuppressor::Suppress(float* in_ptr, 233 float* spectral_mean, 234 float* out_ptr) { 235 // Go to frequency domain. 236 for (size_t i = 0; i < analysis_length_; ++i) { 237 // TODO(aluebs): Rename windows 238 fft_buffer_[i] = in_ptr[i] * window_[i]; 239 } 240 241 WebRtc_rdft(analysis_length_, 1, fft_buffer_.get(), ip_.get(), wfft_.get()); 242 243 // Since WebRtc_rdft puts R[n/2] in fft_buffer_[1], we move it to the end 244 // for convenience. 245 fft_buffer_[analysis_length_] = fft_buffer_[1]; 246 fft_buffer_[analysis_length_ + 1] = 0.f; 247 fft_buffer_[1] = 0.f; 248 249 for (size_t i = 0; i < complex_analysis_length_; ++i) { 250 magnitudes_[i] = ComplexMagnitude(fft_buffer_[i * 2], 251 fft_buffer_[i * 2 + 1]); 252 } 253 // Restore audio if necessary. 254 if (suppression_enabled_) { 255 if (use_hard_restoration_) { 256 HardRestoration(spectral_mean); 257 } else { 258 SoftRestoration(spectral_mean); 259 } 260 } 261 262 // Update the spectral mean. 263 for (size_t i = 0; i < complex_analysis_length_; ++i) { 264 spectral_mean[i] = (1 - kMeanIIRCoefficient) * spectral_mean[i] + 265 kMeanIIRCoefficient * magnitudes_[i]; 266 } 267 268 // Back to time domain. 269 // Put R[n/2] back in fft_buffer_[1]. 270 fft_buffer_[1] = fft_buffer_[analysis_length_]; 271 272 WebRtc_rdft(analysis_length_, 273 -1, 274 fft_buffer_.get(), 275 ip_.get(), 276 wfft_.get()); 277 const float fft_scaling = 2.f / analysis_length_; 278 279 for (size_t i = 0; i < analysis_length_; ++i) { 280 out_ptr[i] += fft_buffer_[i] * window_[i] * fft_scaling; 281 } 282 } 283 284 void TransientSuppressor::UpdateKeypress(bool key_pressed) { 285 const int kKeypressPenalty = 1000 / ts::kChunkSizeMs; 286 const int kIsTypingThreshold = 1000 / ts::kChunkSizeMs; 287 const int kChunksUntilNotTyping = 4000 / ts::kChunkSizeMs; // 4 seconds. 288 289 if (key_pressed) { 290 keypress_counter_ += kKeypressPenalty; 291 chunks_since_keypress_ = 0; 292 detection_enabled_ = true; 293 } 294 keypress_counter_ = std::max(0, keypress_counter_ - 1); 295 296 if (keypress_counter_ > kIsTypingThreshold) { 297 if (!suppression_enabled_) { 298 LOG(LS_INFO) << "[ts] Transient suppression is now enabled."; 299 } 300 suppression_enabled_ = true; 301 keypress_counter_ = 0; 302 } 303 304 if (detection_enabled_ && 305 ++chunks_since_keypress_ > kChunksUntilNotTyping) { 306 if (suppression_enabled_) { 307 LOG(LS_INFO) << "[ts] Transient suppression is now disabled."; 308 } 309 detection_enabled_ = false; 310 suppression_enabled_ = false; 311 keypress_counter_ = 0; 312 } 313 } 314 315 void TransientSuppressor::UpdateRestoration(float voice_probability) { 316 const int kHardRestorationOffsetDelay = 3; 317 const int kHardRestorationOnsetDelay = 80; 318 319 bool not_voiced = voice_probability < kVoiceThreshold; 320 321 if (not_voiced == use_hard_restoration_) { 322 chunks_since_voice_change_ = 0; 323 } else { 324 ++chunks_since_voice_change_; 325 326 if ((use_hard_restoration_ && 327 chunks_since_voice_change_ > kHardRestorationOffsetDelay) || 328 (!use_hard_restoration_ && 329 chunks_since_voice_change_ > kHardRestorationOnsetDelay)) { 330 use_hard_restoration_ = not_voiced; 331 chunks_since_voice_change_ = 0; 332 } 333 } 334 } 335 336 // Shift buffers to make way for new data. Must be called after 337 // |detection_enabled_| is updated by UpdateKeypress(). 338 void TransientSuppressor::UpdateBuffers(float* data) { 339 // TODO(aluebs): Change to ring buffer. 340 memmove(in_buffer_.get(), 341 &in_buffer_[data_length_], 342 (buffer_delay_ + (num_channels_ - 1) * analysis_length_) * 343 sizeof(in_buffer_[0])); 344 // Copy new chunk to buffer. 345 for (int i = 0; i < num_channels_; ++i) { 346 memcpy(&in_buffer_[buffer_delay_ + i * analysis_length_], 347 &data[i * data_length_], 348 data_length_ * sizeof(*data)); 349 } 350 if (detection_enabled_) { 351 // Shift previous chunk in out buffer. 352 memmove(out_buffer_.get(), 353 &out_buffer_[data_length_], 354 (buffer_delay_ + (num_channels_ - 1) * analysis_length_) * 355 sizeof(out_buffer_[0])); 356 // Initialize new chunk in out buffer. 357 for (int i = 0; i < num_channels_; ++i) { 358 memset(&out_buffer_[buffer_delay_ + i * analysis_length_], 359 0, 360 data_length_ * sizeof(out_buffer_[0])); 361 } 362 } 363 } 364 365 // Restores the unvoiced signal if a click is present. 366 // Attenuates by a certain factor every peak in the |fft_buffer_| that exceeds 367 // the spectral mean. The attenuation depends on |detector_smoothed_|. 368 // If a restoration takes place, the |magnitudes_| are updated to the new value. 369 void TransientSuppressor::HardRestoration(float* spectral_mean) { 370 const float detector_result = 371 1.f - pow(1.f - detector_smoothed_, using_reference_ ? 200.f : 50.f); 372 // To restore, we get the peaks in the spectrum. If higher than the previous 373 // spectral mean we adjust them. 374 for (size_t i = 0; i < complex_analysis_length_; ++i) { 375 if (magnitudes_[i] > spectral_mean[i] && magnitudes_[i] > 0) { 376 // RandU() generates values on [0, int16::max()] 377 const float phase = 2 * ts::kPi * WebRtcSpl_RandU(&seed_) / 378 std::numeric_limits<int16_t>::max(); 379 const float scaled_mean = detector_result * spectral_mean[i]; 380 381 fft_buffer_[i * 2] = (1 - detector_result) * fft_buffer_[i * 2] + 382 scaled_mean * cosf(phase); 383 fft_buffer_[i * 2 + 1] = (1 - detector_result) * fft_buffer_[i * 2 + 1] + 384 scaled_mean * sinf(phase); 385 magnitudes_[i] = magnitudes_[i] - 386 detector_result * (magnitudes_[i] - spectral_mean[i]); 387 } 388 } 389 } 390 391 // Restores the voiced signal if a click is present. 392 // Attenuates by a certain factor every peak in the |fft_buffer_| that exceeds 393 // the spectral mean and that is lower than some function of the current block 394 // frequency mean. The attenuation depends on |detector_smoothed_|. 395 // If a restoration takes place, the |magnitudes_| are updated to the new value. 396 void TransientSuppressor::SoftRestoration(float* spectral_mean) { 397 // Get the spectral magnitude mean of the current block. 398 float block_frequency_mean = 0; 399 for (size_t i = kMinVoiceBin; i < kMaxVoiceBin; ++i) { 400 block_frequency_mean += magnitudes_[i]; 401 } 402 block_frequency_mean /= (kMaxVoiceBin - kMinVoiceBin); 403 404 // To restore, we get the peaks in the spectrum. If higher than the 405 // previous spectral mean and lower than a factor of the block mean 406 // we adjust them. The factor is a double sigmoid that has a minimum in the 407 // voice frequency range (300Hz - 3kHz). 408 for (size_t i = 0; i < complex_analysis_length_; ++i) { 409 if (magnitudes_[i] > spectral_mean[i] && magnitudes_[i] > 0 && 410 (using_reference_ || 411 magnitudes_[i] < block_frequency_mean * mean_factor_[i])) { 412 const float new_magnitude = 413 magnitudes_[i] - 414 detector_smoothed_ * (magnitudes_[i] - spectral_mean[i]); 415 const float magnitude_ratio = new_magnitude / magnitudes_[i]; 416 417 fft_buffer_[i * 2] *= magnitude_ratio; 418 fft_buffer_[i * 2 + 1] *= magnitude_ratio; 419 magnitudes_[i] = new_magnitude; 420 } 421 } 422 } 423 424 } // namespace webrtc 425