Home | History | Annotate | Download | only in beamformer
      1 /*
      2  *  Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
      3  *
      4  *  Use of this source code is governed by a BSD-style license
      5  *  that can be found in the LICENSE file in the root of the source
      6  *  tree. An additional intellectual property rights grant can be found
      7  *  in the file PATENTS.  All contributing project authors may
      8  *  be found in the AUTHORS file in the root of the source tree.
      9  */
     10 
     11 #ifndef WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_NONLINEAR_BEAMFORMER_H_
     12 #define WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_NONLINEAR_BEAMFORMER_H_
     13 
     14 // MSVC++ requires this to be set before any other includes to get M_PI.
     15 #define _USE_MATH_DEFINES
     16 
     17 #include <math.h>
     18 #include <vector>
     19 
     20 #include "webrtc/common_audio/lapped_transform.h"
     21 #include "webrtc/common_audio/channel_buffer.h"
     22 #include "webrtc/modules/audio_processing/beamformer/beamformer.h"
     23 #include "webrtc/modules/audio_processing/beamformer/complex_matrix.h"
     24 #include "webrtc/system_wrappers/include/scoped_vector.h"
     25 
     26 namespace webrtc {
     27 
     28 // Enhances sound sources coming directly in front of a uniform linear array
     29 // and suppresses sound sources coming from all other directions. Operates on
     30 // multichannel signals and produces single-channel output.
     31 //
     32 // The implemented nonlinear postfilter algorithm is taken from "A Robust
     33 // Nonlinear Beamforming Postprocessor" by Bastiaan Kleijn.
        //
        // Usage, per the member comments below: construct with the array geometry,
        // call Initialize() before anything else, then feed audio chunks through
        // ProcessChunk().
     34 class NonlinearBeamformer
     35   : public Beamformer<float>,
     36     public LappedTransform::Callback {
     37  public:
          // Per its name, half of the beam width in radians. Only declared here; the
          // value is defined outside this header.
     38   static const float kHalfBeamWidthRadians;
     39 
          // |array_geometry| describes the microphone array (presumably one Point per
          // input channel; confirm in the .cc file). |target_direction| defaults to
          // SphericalPointf(pi / 2, 0, 1) when the caller does not supply one.
     40   explicit NonlinearBeamformer(
     41       const std::vector<Point>& array_geometry,
     42       SphericalPointf target_direction =
     43           SphericalPointf(static_cast<float>(M_PI) / 2.f, 0.f, 1.f));
     44 
     45   // Sample rate corresponds to the lower band.
     46   // Needs to be called before the NonlinearBeamformer can be used.
     47   void Initialize(int chunk_size_ms, int sample_rate_hz) override;
     48 
     49   // Process one time-domain chunk of audio. The audio is expected to be split
     50   // into frequency bands inside the ChannelBuffer. The number of frames and
     51   // channels must correspond to the constructor parameters. The same
     52   // ChannelBuffer can be passed in as |input| and |output|.
     53   void ProcessChunk(const ChannelBuffer<float>& input,
     54                     ChannelBuffer<float>* output) override;
     55 
          // NOTE(review): judging by the name, re-targets the beam toward
          // |target_direction|; the implementation is not visible in this header.
     56   void AimAt(const SphericalPointf& target_direction) override;
     57 
          // NOTE(review): judging by the name, reports whether |spherical_point| lies
          // inside the current beam; confirm the exact criterion in the .cc file.
     58   bool IsInBeam(const SphericalPointf& spherical_point) override;
     59 
     60   // After processing each block |is_target_present_| is set to true if the
     61   // target signal is present and to false otherwise. This method can be called
     62   // to know if the data is target signal or interference and process it
     63   // accordingly.
     64   bool is_target_present() override { return is_target_present_; }
     65 
     66  protected:
     67   // Process one frequency-domain block of audio. This is where the fun
     68   // happens. Implements LappedTransform::Callback.
     69   void ProcessAudioBlock(const complex<float>* const* input,
     70                          size_t num_input_channels,
     71                          size_t num_freq_bins,
     72                          size_t num_output_channels,
     73                          complex<float>* const* output) override;
     74 
     75  private:
          // Presumably grants the named unit test access to private members; the
          // FRIEND_TEST_ALL_PREFIXES macro is defined outside this header.
     76   FRIEND_TEST_ALL_PREFIXES(NonlinearBeamformerTest,
     77                            InterfAnglesTakeAmbiguityIntoAccount);
     78 
          // Shorthand aliases used by the declarations below.
     79   typedef Matrix<float> MatrixF;
     80   typedef ComplexMatrix<float> ComplexMatrixF;
     81   typedef complex<float> complex_f;
     82 
          // Setup helpers for the state declared further down. NOTE(review): their
          // call sites and ordering live in the .cc file, which is not visible here.
     83   void InitLowFrequencyCorrectionRanges();
     84   void InitHighFrequencyCorrectionRanges();
     85   void InitInterfAngles();
     86   void InitDelaySumMasks();
     87   void InitTargetCovMats();
     88   void InitDiffuseCovMats();
     89   void InitInterfCovMats();
     90   void NormalizeCovMats();
     91 
     92   // Calculates postfilter masks that minimize the mean squared error of our
     93   // estimation of the desired signal.
     94   float CalculatePostfilterMask(const ComplexMatrixF& interf_cov_mat,
     95                                 float rpsiw,
     96                                 float ratio_rxiw_rxim,
     97                                 float rmxi_r);
     98 
     99   // Prevents the postfilter masks from degenerating too quickly (a cause of
    100   // musical noise).
    101   void ApplyMaskTimeSmoothing();
    102   void ApplyMaskFrequencySmoothing();
    103 
    104   // The postfilter masks are unreliable at low frequencies. Calculates a better
    105   // mask by averaging mid-low frequency values.
    106   void ApplyLowFrequencyCorrection();
    107 
    108   // Postfilter masks are also unreliable at high frequencies. Average mid-high
    109   // frequency masks to calculate a single mask per block which can be applied
    110   // in the time-domain. Further, we average these block-masks over a chunk,
    111   // resulting in one postfilter mask per audio chunk. This allows us to skip
    112   // both transforming and blocking the high-frequency signal.
    113   void ApplyHighFrequencyCorrection();
    114 
    115   // Compute the means needed for the above frequency correction.
    116   float MaskRangeMean(size_t start_bin, size_t end_bin);
    117 
    118   // Applies both sets of masks to |input| and store in |output|.
    119   void ApplyMasks(const complex_f* const* input, complex_f* const* output);
    120 
          // NOTE(review): presumably updates |is_target_present_| using the hold and
          // interference counters declared at the end of the class; confirm in the
          // .cc file.
    121   void EstimateTargetPresence();
    122 
          // FFT length and the resulting number of frequency bins
          // (256 / 2 + 1 = 129). Every per-bin array below has |kNumFreqBins| entries.
    123   static const size_t kFftSize = 256;
    124   static const size_t kNumFreqBins = kFftSize / 2 + 1;
    125 
    126   // Deals with the fft transform and blocking.
    127   size_t chunk_length_;
    128   rtc::scoped_ptr<LappedTransform> lapped_transform_;
    129   float window_[kFftSize];
    130 
    131   // Parameters exposed to the user.
    132   const size_t num_input_channels_;
    133   int sample_rate_hz_;
    134 
    135   const std::vector<Point> array_geometry_;
    136   // The normal direction of the array if it has one and it is in the xy-plane.
    137   const rtc::Optional<Point> array_normal_;
    138 
    139   // Minimum spacing between microphone pairs.
    140   const float min_mic_spacing_;
    141 
    142   // Calculated based on user-input and constants in the .cc file.
    143   size_t low_mean_start_bin_;
    144   size_t low_mean_end_bin_;
    145   size_t high_mean_start_bin_;
    146   size_t high_mean_end_bin_;
    147 
    148   // Quickly varying mask updated every block.
    149   float new_mask_[kNumFreqBins];
    150   // Time smoothed mask.
    151   float time_smooth_mask_[kNumFreqBins];
    152   // Time and frequency smoothed mask.
    153   float final_mask_[kNumFreqBins];
    154 
    155   float target_angle_radians_;
    156   // Angles of the interferer scenarios.
    157   std::vector<float> interf_angles_radians_;
    158   // The angle between the target and the interferer scenarios.
    159   const float away_radians_;
    160 
    161   // Array of length |kNumFreqBins|, Matrix of size |1| x
          // |num_input_channels_|.
    162   ComplexMatrixF delay_sum_masks_[kNumFreqBins];
    163   ComplexMatrixF normalized_delay_sum_masks_[kNumFreqBins];
    164 
    165   // Arrays of length |kNumFreqBins|, Matrix of size |num_input_channels_| x
    166   // |num_input_channels_|.
    167   ComplexMatrixF target_cov_mats_[kNumFreqBins];
    168   ComplexMatrixF uniform_cov_mat_[kNumFreqBins];
    169   // Array of length |kNumFreqBins|, Matrix of size |num_input_channels_| x
    170   // |num_input_channels_|. ScopedVector has a size equal to the number of
    171   // interferer scenarios.
    172   ScopedVector<ComplexMatrixF> interf_cov_mats_[kNumFreqBins];
    173 
    174   // Of length |kNumFreqBins|.
    175   float wave_numbers_[kNumFreqBins];
    176 
    177   // Preallocated for ProcessAudioBlock()
    178   // Of length |kNumFreqBins|.
    179   float rxiws_[kNumFreqBins];
    180   // The vector has a size equal to the number of interferer scenarios.
    181   std::vector<float> rpsiws_[kNumFreqBins];
    182 
    183   // The microphone normalization factor.
    184   ComplexMatrixF eig_m_;
    185 
    186   // For processing the high-frequency input signal.
    187   float high_pass_postfilter_mask_;
    188 
    189   // True when the target signal is present.
    190   bool is_target_present_;
    191   // Number of blocks after which the data is considered interference if the
    192   // mask does not pass |kMaskSignalThreshold|.
    193   size_t hold_target_blocks_;
    194   // Number of blocks since the last mask that passed |kMaskSignalThreshold|.
    195   size_t interference_blocks_count_;
    196 };
    197 
    198 }  // namespace webrtc
    199 
    200 #endif  // WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_NONLINEAR_BEAMFORMER_H_
    201