Home | History | Annotate | Download | only in transient
      1 /*
      2  *  Copyright (c) 2013 The WebRTC project authors. All Rights Reserved.
      3  *
      4  *  Use of this source code is governed by a BSD-style license
      5  *  that can be found in the LICENSE file in the root of the source
      6  *  tree. An additional intellectual property rights grant can be found
      7  *  in the file PATENTS.  All contributing project authors may
      8  *  be found in the AUTHORS file in the root of the source tree.
      9  */
     10 
#include "webrtc/modules/audio_processing/transient/transient_suppressor.h"

#include <assert.h>
#include <math.h>
#include <string.h>

#include <algorithm>
#include <cmath>
#include <complex>
#include <deque>
#include <limits>
#include <set>

#include "webrtc/base/scoped_ptr.h"
#include "webrtc/common_audio/fft4g.h"
#include "webrtc/common_audio/include/audio_util.h"
#include "webrtc/common_audio/signal_processing/include/signal_processing_library.h"
#include "webrtc/modules/audio_processing/ns/windows_private.h"
#include "webrtc/modules/audio_processing/transient/common.h"
#include "webrtc/modules/audio_processing/transient/transient_detector.h"
#include "webrtc/system_wrappers/include/logging.h"
#include "webrtc/typedefs.h"
     29 
     30 namespace webrtc {
     31 
     32 static const float kMeanIIRCoefficient = 0.5f;
     33 static const float kVoiceThreshold = 0.02f;
     34 
     35 // TODO(aluebs): Check if these values work also for 48kHz.
     36 static const size_t kMinVoiceBin = 3;
     37 static const size_t kMaxVoiceBin = 60;
     38 
     39 namespace {
     40 
     41 float ComplexMagnitude(float a, float b) {
     42   return std::abs(a) + std::abs(b);
     43 }
     44 
     45 }  // namespace
     46 
     47 TransientSuppressor::TransientSuppressor()
     48     : data_length_(0),
     49       detection_length_(0),
     50       analysis_length_(0),
     51       buffer_delay_(0),
     52       complex_analysis_length_(0),
     53       num_channels_(0),
     54       window_(NULL),
     55       detector_smoothed_(0.f),
     56       keypress_counter_(0),
     57       chunks_since_keypress_(0),
     58       detection_enabled_(false),
     59       suppression_enabled_(false),
     60       use_hard_restoration_(false),
     61       chunks_since_voice_change_(0),
     62       seed_(182),
     63       using_reference_(false) {
     64 }
     65 
     66 TransientSuppressor::~TransientSuppressor() {}
     67 
     68 int TransientSuppressor::Initialize(int sample_rate_hz,
     69                                     int detection_rate_hz,
     70                                     int num_channels) {
     71   switch (sample_rate_hz) {
     72     case ts::kSampleRate8kHz:
     73       analysis_length_ = 128u;
     74       window_ = kBlocks80w128;
     75       break;
     76     case ts::kSampleRate16kHz:
     77       analysis_length_ = 256u;
     78       window_ = kBlocks160w256;
     79       break;
     80     case ts::kSampleRate32kHz:
     81       analysis_length_ = 512u;
     82       window_ = kBlocks320w512;
     83       break;
     84     case ts::kSampleRate48kHz:
     85       analysis_length_ = 1024u;
     86       window_ = kBlocks480w1024;
     87       break;
     88     default:
     89       return -1;
     90   }
     91   if (detection_rate_hz != ts::kSampleRate8kHz &&
     92       detection_rate_hz != ts::kSampleRate16kHz &&
     93       detection_rate_hz != ts::kSampleRate32kHz &&
     94       detection_rate_hz != ts::kSampleRate48kHz) {
     95     return -1;
     96   }
     97   if (num_channels <= 0) {
     98     return -1;
     99   }
    100 
    101   detector_.reset(new TransientDetector(detection_rate_hz));
    102   data_length_ = sample_rate_hz * ts::kChunkSizeMs / 1000;
    103   if (data_length_ > analysis_length_) {
    104     assert(false);
    105     return -1;
    106   }
    107   buffer_delay_ = analysis_length_ - data_length_;
    108 
    109   complex_analysis_length_ = analysis_length_ / 2 + 1;
    110   assert(complex_analysis_length_ >= kMaxVoiceBin);
    111   num_channels_ = num_channels;
    112   in_buffer_.reset(new float[analysis_length_ * num_channels_]);
    113   memset(in_buffer_.get(),
    114          0,
    115          analysis_length_ * num_channels_ * sizeof(in_buffer_[0]));
    116   detection_length_ = detection_rate_hz * ts::kChunkSizeMs / 1000;
    117   detection_buffer_.reset(new float[detection_length_]);
    118   memset(detection_buffer_.get(),
    119          0,
    120          detection_length_ * sizeof(detection_buffer_[0]));
    121   out_buffer_.reset(new float[analysis_length_ * num_channels_]);
    122   memset(out_buffer_.get(),
    123          0,
    124          analysis_length_ * num_channels_ * sizeof(out_buffer_[0]));
    125   // ip[0] must be zero to trigger initialization using rdft().
    126   size_t ip_length = 2 + sqrtf(analysis_length_);
    127   ip_.reset(new size_t[ip_length]());
    128   memset(ip_.get(), 0, ip_length * sizeof(ip_[0]));
    129   wfft_.reset(new float[complex_analysis_length_ - 1]);
    130   memset(wfft_.get(), 0, (complex_analysis_length_ - 1) * sizeof(wfft_[0]));
    131   spectral_mean_.reset(new float[complex_analysis_length_ * num_channels_]);
    132   memset(spectral_mean_.get(),
    133          0,
    134          complex_analysis_length_ * num_channels_ * sizeof(spectral_mean_[0]));
    135   fft_buffer_.reset(new float[analysis_length_ + 2]);
    136   memset(fft_buffer_.get(), 0, (analysis_length_ + 2) * sizeof(fft_buffer_[0]));
    137   magnitudes_.reset(new float[complex_analysis_length_]);
    138   memset(magnitudes_.get(),
    139          0,
    140          complex_analysis_length_ * sizeof(magnitudes_[0]));
    141   mean_factor_.reset(new float[complex_analysis_length_]);
    142 
    143   static const float kFactorHeight = 10.f;
    144   static const float kLowSlope = 1.f;
    145   static const float kHighSlope = 0.3f;
    146   for (size_t i = 0; i < complex_analysis_length_; ++i) {
    147     mean_factor_[i] =
    148         kFactorHeight /
    149             (1.f + exp(kLowSlope * static_cast<int>(i - kMinVoiceBin))) +
    150         kFactorHeight /
    151             (1.f + exp(kHighSlope * static_cast<int>(kMaxVoiceBin - i)));
    152   }
    153   detector_smoothed_ = 0.f;
    154   keypress_counter_ = 0;
    155   chunks_since_keypress_ = 0;
    156   detection_enabled_ = false;
    157   suppression_enabled_ = false;
    158   use_hard_restoration_ = false;
    159   chunks_since_voice_change_ = 0;
    160   seed_ = 182;
    161   using_reference_ = false;
    162   return 0;
    163 }
    164 
    165 int TransientSuppressor::Suppress(float* data,
    166                                   size_t data_length,
    167                                   int num_channels,
    168                                   const float* detection_data,
    169                                   size_t detection_length,
    170                                   const float* reference_data,
    171                                   size_t reference_length,
    172                                   float voice_probability,
    173                                   bool key_pressed) {
    174   if (!data || data_length != data_length_ || num_channels != num_channels_ ||
    175       detection_length != detection_length_ || voice_probability < 0 ||
    176       voice_probability > 1) {
    177     return -1;
    178   }
    179 
    180   UpdateKeypress(key_pressed);
    181   UpdateBuffers(data);
    182 
    183   int result = 0;
    184   if (detection_enabled_) {
    185     UpdateRestoration(voice_probability);
    186 
    187     if (!detection_data) {
    188       // Use the input data  of the first channel if special detection data is
    189       // not supplied.
    190       detection_data = &in_buffer_[buffer_delay_];
    191     }
    192 
    193     float detector_result = detector_->Detect(
    194         detection_data, detection_length, reference_data, reference_length);
    195     if (detector_result < 0) {
    196       return -1;
    197     }
    198 
    199     using_reference_ = detector_->using_reference();
    200 
    201     // |detector_smoothed_| follows the |detector_result| when this last one is
    202     // increasing, but has an exponential decaying tail to be able to suppress
    203     // the ringing of keyclicks.
    204     float smooth_factor = using_reference_ ? 0.6 : 0.1;
    205     detector_smoothed_ = detector_result >= detector_smoothed_
    206                              ? detector_result
    207                              : smooth_factor * detector_smoothed_ +
    208                                    (1 - smooth_factor) * detector_result;
    209 
    210     for (int i = 0; i < num_channels_; ++i) {
    211       Suppress(&in_buffer_[i * analysis_length_],
    212                &spectral_mean_[i * complex_analysis_length_],
    213                &out_buffer_[i * analysis_length_]);
    214     }
    215   }
    216 
    217   // If the suppression isn't enabled, we use the in buffer to delay the signal
    218   // appropriately. This also gives time for the out buffer to be refreshed with
    219   // new data between detection and suppression getting enabled.
    220   for (int i = 0; i < num_channels_; ++i) {
    221     memcpy(&data[i * data_length_],
    222            suppression_enabled_ ? &out_buffer_[i * analysis_length_]
    223                                 : &in_buffer_[i * analysis_length_],
    224            data_length_ * sizeof(*data));
    225   }
    226   return result;
    227 }
    228 
    229 // This should only be called when detection is enabled. UpdateBuffers() must
    230 // have been called. At return, |out_buffer_| will be filled with the
    231 // processed output.
    232 void TransientSuppressor::Suppress(float* in_ptr,
    233                                    float* spectral_mean,
    234                                    float* out_ptr) {
    235   // Go to frequency domain.
    236   for (size_t i = 0; i < analysis_length_; ++i) {
    237     // TODO(aluebs): Rename windows
    238     fft_buffer_[i] = in_ptr[i] * window_[i];
    239   }
    240 
    241   WebRtc_rdft(analysis_length_, 1, fft_buffer_.get(), ip_.get(), wfft_.get());
    242 
    243   // Since WebRtc_rdft puts R[n/2] in fft_buffer_[1], we move it to the end
    244   // for convenience.
    245   fft_buffer_[analysis_length_] = fft_buffer_[1];
    246   fft_buffer_[analysis_length_ + 1] = 0.f;
    247   fft_buffer_[1] = 0.f;
    248 
    249   for (size_t i = 0; i < complex_analysis_length_; ++i) {
    250     magnitudes_[i] = ComplexMagnitude(fft_buffer_[i * 2],
    251                                       fft_buffer_[i * 2 + 1]);
    252   }
    253   // Restore audio if necessary.
    254   if (suppression_enabled_) {
    255     if (use_hard_restoration_) {
    256       HardRestoration(spectral_mean);
    257     } else {
    258       SoftRestoration(spectral_mean);
    259     }
    260   }
    261 
    262   // Update the spectral mean.
    263   for (size_t i = 0; i < complex_analysis_length_; ++i) {
    264     spectral_mean[i] = (1 - kMeanIIRCoefficient) * spectral_mean[i] +
    265                        kMeanIIRCoefficient * magnitudes_[i];
    266   }
    267 
    268   // Back to time domain.
    269   // Put R[n/2] back in fft_buffer_[1].
    270   fft_buffer_[1] = fft_buffer_[analysis_length_];
    271 
    272   WebRtc_rdft(analysis_length_,
    273               -1,
    274               fft_buffer_.get(),
    275               ip_.get(),
    276               wfft_.get());
    277   const float fft_scaling = 2.f / analysis_length_;
    278 
    279   for (size_t i = 0; i < analysis_length_; ++i) {
    280     out_ptr[i] += fft_buffer_[i] * window_[i] * fft_scaling;
    281   }
    282 }
    283 
    284 void TransientSuppressor::UpdateKeypress(bool key_pressed) {
    285   const int kKeypressPenalty = 1000 / ts::kChunkSizeMs;
    286   const int kIsTypingThreshold = 1000 / ts::kChunkSizeMs;
    287   const int kChunksUntilNotTyping = 4000 / ts::kChunkSizeMs;  // 4 seconds.
    288 
    289   if (key_pressed) {
    290     keypress_counter_ += kKeypressPenalty;
    291     chunks_since_keypress_ = 0;
    292     detection_enabled_ = true;
    293   }
    294   keypress_counter_ = std::max(0, keypress_counter_ - 1);
    295 
    296   if (keypress_counter_ > kIsTypingThreshold) {
    297     if (!suppression_enabled_) {
    298       LOG(LS_INFO) << "[ts] Transient suppression is now enabled.";
    299     }
    300     suppression_enabled_ = true;
    301     keypress_counter_ = 0;
    302   }
    303 
    304   if (detection_enabled_ &&
    305       ++chunks_since_keypress_ > kChunksUntilNotTyping) {
    306     if (suppression_enabled_) {
    307       LOG(LS_INFO) << "[ts] Transient suppression is now disabled.";
    308     }
    309     detection_enabled_ = false;
    310     suppression_enabled_ = false;
    311     keypress_counter_ = 0;
    312   }
    313 }
    314 
    315 void TransientSuppressor::UpdateRestoration(float voice_probability) {
    316   const int kHardRestorationOffsetDelay = 3;
    317   const int kHardRestorationOnsetDelay = 80;
    318 
    319   bool not_voiced = voice_probability < kVoiceThreshold;
    320 
    321   if (not_voiced == use_hard_restoration_) {
    322     chunks_since_voice_change_ = 0;
    323   } else {
    324     ++chunks_since_voice_change_;
    325 
    326     if ((use_hard_restoration_ &&
    327          chunks_since_voice_change_ > kHardRestorationOffsetDelay) ||
    328         (!use_hard_restoration_ &&
    329          chunks_since_voice_change_ > kHardRestorationOnsetDelay)) {
    330       use_hard_restoration_ = not_voiced;
    331       chunks_since_voice_change_ = 0;
    332     }
    333   }
    334 }
    335 
    336 // Shift buffers to make way for new data. Must be called after
    337 // |detection_enabled_| is updated by UpdateKeypress().
    338 void TransientSuppressor::UpdateBuffers(float* data) {
    339   // TODO(aluebs): Change to ring buffer.
    340   memmove(in_buffer_.get(),
    341           &in_buffer_[data_length_],
    342           (buffer_delay_ + (num_channels_ - 1) * analysis_length_) *
    343               sizeof(in_buffer_[0]));
    344   // Copy new chunk to buffer.
    345   for (int i = 0; i < num_channels_; ++i) {
    346     memcpy(&in_buffer_[buffer_delay_ + i * analysis_length_],
    347            &data[i * data_length_],
    348            data_length_ * sizeof(*data));
    349   }
    350   if (detection_enabled_) {
    351     // Shift previous chunk in out buffer.
    352     memmove(out_buffer_.get(),
    353             &out_buffer_[data_length_],
    354             (buffer_delay_ + (num_channels_ - 1) * analysis_length_) *
    355                 sizeof(out_buffer_[0]));
    356     // Initialize new chunk in out buffer.
    357     for (int i = 0; i < num_channels_; ++i) {
    358       memset(&out_buffer_[buffer_delay_ + i * analysis_length_],
    359              0,
    360              data_length_ * sizeof(out_buffer_[0]));
    361     }
    362   }
    363 }
    364 
    365 // Restores the unvoiced signal if a click is present.
    366 // Attenuates by a certain factor every peak in the |fft_buffer_| that exceeds
    367 // the spectral mean. The attenuation depends on |detector_smoothed_|.
    368 // If a restoration takes place, the |magnitudes_| are updated to the new value.
    369 void TransientSuppressor::HardRestoration(float* spectral_mean) {
    370   const float detector_result =
    371       1.f - pow(1.f - detector_smoothed_, using_reference_ ? 200.f : 50.f);
    372   // To restore, we get the peaks in the spectrum. If higher than the previous
    373   // spectral mean we adjust them.
    374   for (size_t i = 0; i < complex_analysis_length_; ++i) {
    375     if (magnitudes_[i] > spectral_mean[i] && magnitudes_[i] > 0) {
    376       // RandU() generates values on [0, int16::max()]
    377       const float phase = 2 * ts::kPi * WebRtcSpl_RandU(&seed_) /
    378           std::numeric_limits<int16_t>::max();
    379       const float scaled_mean = detector_result * spectral_mean[i];
    380 
    381       fft_buffer_[i * 2] = (1 - detector_result) * fft_buffer_[i * 2] +
    382                            scaled_mean * cosf(phase);
    383       fft_buffer_[i * 2 + 1] = (1 - detector_result) * fft_buffer_[i * 2 + 1] +
    384                                scaled_mean * sinf(phase);
    385       magnitudes_[i] = magnitudes_[i] -
    386                        detector_result * (magnitudes_[i] - spectral_mean[i]);
    387     }
    388   }
    389 }
    390 
    391 // Restores the voiced signal if a click is present.
    392 // Attenuates by a certain factor every peak in the |fft_buffer_| that exceeds
    393 // the spectral mean and that is lower than some function of the current block
    394 // frequency mean. The attenuation depends on |detector_smoothed_|.
    395 // If a restoration takes place, the |magnitudes_| are updated to the new value.
    396 void TransientSuppressor::SoftRestoration(float* spectral_mean) {
    397   // Get the spectral magnitude mean of the current block.
    398   float block_frequency_mean = 0;
    399   for (size_t i = kMinVoiceBin; i < kMaxVoiceBin; ++i) {
    400     block_frequency_mean += magnitudes_[i];
    401   }
    402   block_frequency_mean /= (kMaxVoiceBin - kMinVoiceBin);
    403 
    404   // To restore, we get the peaks in the spectrum. If higher than the
    405   // previous spectral mean and lower than a factor of the block mean
    406   // we adjust them. The factor is a double sigmoid that has a minimum in the
    407   // voice frequency range (300Hz - 3kHz).
    408   for (size_t i = 0; i < complex_analysis_length_; ++i) {
    409     if (magnitudes_[i] > spectral_mean[i] && magnitudes_[i] > 0 &&
    410         (using_reference_ ||
    411          magnitudes_[i] < block_frequency_mean * mean_factor_[i])) {
    412       const float new_magnitude =
    413           magnitudes_[i] -
    414           detector_smoothed_ * (magnitudes_[i] - spectral_mean[i]);
    415       const float magnitude_ratio = new_magnitude / magnitudes_[i];
    416 
    417       fft_buffer_[i * 2] *= magnitude_ratio;
    418       fft_buffer_[i * 2 + 1] *= magnitude_ratio;
    419       magnitudes_[i] = new_magnitude;
    420     }
    421   }
    422 }
    423 
    424 }  // namespace webrtc
    425