1 /* 2 * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #ifndef WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_NONLINEAR_BEAMFORMER_H_ 12 #define WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_NONLINEAR_BEAMFORMER_H_ 13 14 // MSVC++ requires this to be set before any other includes to get M_PI. 15 #define _USE_MATH_DEFINES 16 17 #include <math.h> 18 #include <vector> 19 20 #include "webrtc/common_audio/lapped_transform.h" 21 #include "webrtc/common_audio/channel_buffer.h" 22 #include "webrtc/modules/audio_processing/beamformer/beamformer.h" 23 #include "webrtc/modules/audio_processing/beamformer/complex_matrix.h" 24 #include "webrtc/system_wrappers/include/scoped_vector.h" 25 26 namespace webrtc { 27 28 // Enhances sound sources coming directly in front of a uniform linear array 29 // and suppresses sound sources coming from all other directions. Operates on 30 // multichannel signals and produces single-channel output. 31 // 32 // The implemented nonlinear postfilter algorithm taken from "A Robust Nonlinear 33 // Beamforming Postprocessor" by Bastiaan Kleijn. 34 class NonlinearBeamformer 35 : public Beamformer<float>, 36 public LappedTransform::Callback { 37 public: 38 static const float kHalfBeamWidthRadians; 39 40 explicit NonlinearBeamformer( 41 const std::vector<Point>& array_geometry, 42 SphericalPointf target_direction = 43 SphericalPointf(static_cast<float>(M_PI) / 2.f, 0.f, 1.f)); 44 45 // Sample rate corresponds to the lower band. 46 // Needs to be called before the NonlinearBeamformer can be used. 47 void Initialize(int chunk_size_ms, int sample_rate_hz) override; 48 49 // Process one time-domain chunk of audio. The audio is expected to be split 50 // into frequency bands inside the ChannelBuffer. The number of frames and 51 // channels must correspond to the constructor parameters. The same 52 // ChannelBuffer can be passed in as |input| and |output|. 53 void ProcessChunk(const ChannelBuffer<float>& input, 54 ChannelBuffer<float>* output) override; 55 56 void AimAt(const SphericalPointf& target_direction) override; 57 58 bool IsInBeam(const SphericalPointf& spherical_point) override; 59 60 // After processing each block |is_target_present_| is set to true if the 61 // target signal es present and to false otherwise. This methods can be called 62 // to know if the data is target signal or interference and process it 63 // accordingly. 64 bool is_target_present() override { return is_target_present_; } 65 66 protected: 67 // Process one frequency-domain block of audio. This is where the fun 68 // happens. Implements LappedTransform::Callback. 69 void ProcessAudioBlock(const complex<float>* const* input, 70 size_t num_input_channels, 71 size_t num_freq_bins, 72 size_t num_output_channels, 73 complex<float>* const* output) override; 74 75 private: 76 FRIEND_TEST_ALL_PREFIXES(NonlinearBeamformerTest, 77 InterfAnglesTakeAmbiguityIntoAccount); 78 79 typedef Matrix<float> MatrixF; 80 typedef ComplexMatrix<float> ComplexMatrixF; 81 typedef complex<float> complex_f; 82 83 void InitLowFrequencyCorrectionRanges(); 84 void InitHighFrequencyCorrectionRanges(); 85 void InitInterfAngles(); 86 void InitDelaySumMasks(); 87 void InitTargetCovMats(); 88 void InitDiffuseCovMats(); 89 void InitInterfCovMats(); 90 void NormalizeCovMats(); 91 92 // Calculates postfilter masks that minimize the mean squared error of our 93 // estimation of the desired signal. 94 float CalculatePostfilterMask(const ComplexMatrixF& interf_cov_mat, 95 float rpsiw, 96 float ratio_rxiw_rxim, 97 float rmxi_r); 98 99 // Prevents the postfilter masks from degenerating too quickly (a cause of 100 // musical noise). 101 void ApplyMaskTimeSmoothing(); 102 void ApplyMaskFrequencySmoothing(); 103 104 // The postfilter masks are unreliable at low frequencies. Calculates a better 105 // mask by averaging mid-low frequency values. 106 void ApplyLowFrequencyCorrection(); 107 108 // Postfilter masks are also unreliable at high frequencies. Average mid-high 109 // frequency masks to calculate a single mask per block which can be applied 110 // in the time-domain. Further, we average these block-masks over a chunk, 111 // resulting in one postfilter mask per audio chunk. This allows us to skip 112 // both transforming and blocking the high-frequency signal. 113 void ApplyHighFrequencyCorrection(); 114 115 // Compute the means needed for the above frequency correction. 116 float MaskRangeMean(size_t start_bin, size_t end_bin); 117 118 // Applies both sets of masks to |input| and store in |output|. 119 void ApplyMasks(const complex_f* const* input, complex_f* const* output); 120 121 void EstimateTargetPresence(); 122 123 static const size_t kFftSize = 256; 124 static const size_t kNumFreqBins = kFftSize / 2 + 1; 125 126 // Deals with the fft transform and blocking. 127 size_t chunk_length_; 128 rtc::scoped_ptr<LappedTransform> lapped_transform_; 129 float window_[kFftSize]; 130 131 // Parameters exposed to the user. 132 const size_t num_input_channels_; 133 int sample_rate_hz_; 134 135 const std::vector<Point> array_geometry_; 136 // The normal direction of the array if it has one and it is in the xy-plane. 137 const rtc::Optional<Point> array_normal_; 138 139 // Minimum spacing between microphone pairs. 140 const float min_mic_spacing_; 141 142 // Calculated based on user-input and constants in the .cc file. 143 size_t low_mean_start_bin_; 144 size_t low_mean_end_bin_; 145 size_t high_mean_start_bin_; 146 size_t high_mean_end_bin_; 147 148 // Quickly varying mask updated every block. 149 float new_mask_[kNumFreqBins]; 150 // Time smoothed mask. 151 float time_smooth_mask_[kNumFreqBins]; 152 // Time and frequency smoothed mask. 153 float final_mask_[kNumFreqBins]; 154 155 float target_angle_radians_; 156 // Angles of the interferer scenarios. 157 std::vector<float> interf_angles_radians_; 158 // The angle between the target and the interferer scenarios. 159 const float away_radians_; 160 161 // Array of length |kNumFreqBins|, Matrix of size |1| x |num_channels_|. 162 ComplexMatrixF delay_sum_masks_[kNumFreqBins]; 163 ComplexMatrixF normalized_delay_sum_masks_[kNumFreqBins]; 164 165 // Arrays of length |kNumFreqBins|, Matrix of size |num_input_channels_| x 166 // |num_input_channels_|. 167 ComplexMatrixF target_cov_mats_[kNumFreqBins]; 168 ComplexMatrixF uniform_cov_mat_[kNumFreqBins]; 169 // Array of length |kNumFreqBins|, Matrix of size |num_input_channels_| x 170 // |num_input_channels_|. ScopedVector has a size equal to the number of 171 // interferer scenarios. 172 ScopedVector<ComplexMatrixF> interf_cov_mats_[kNumFreqBins]; 173 174 // Of length |kNumFreqBins|. 175 float wave_numbers_[kNumFreqBins]; 176 177 // Preallocated for ProcessAudioBlock() 178 // Of length |kNumFreqBins|. 179 float rxiws_[kNumFreqBins]; 180 // The vector has a size equal to the number of interferer scenarios. 181 std::vector<float> rpsiws_[kNumFreqBins]; 182 183 // The microphone normalization factor. 184 ComplexMatrixF eig_m_; 185 186 // For processing the high-frequency input signal. 187 float high_pass_postfilter_mask_; 188 189 // True when the target signal is present. 190 bool is_target_present_; 191 // Number of blocks after which the data is considered interference if the 192 // mask does not pass |kMaskSignalThreshold|. 193 size_t hold_target_blocks_; 194 // Number of blocks since the last mask that passed |kMaskSignalThreshold|. 195 size_t interference_blocks_count_; 196 }; 197 198 } // namespace webrtc 199 200 #endif // WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_NONLINEAR_BEAMFORMER_H_ 201