1 /* 2 * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #include <math.h> 12 #include <stdio.h> 13 #include "webrtc/base/checks.h" 14 #include "webrtc/modules/audio_coding/neteq/tools/neteq_quality_test.h" 15 #include "webrtc/modules/audio_coding/neteq/tools/output_audio_file.h" 16 #include "webrtc/modules/audio_coding/neteq/tools/output_wav_file.h" 17 #include "webrtc/modules/audio_coding/neteq/tools/resample_input_audio_file.h" 18 #include "webrtc/test/testsupport/fileutils.h" 19 20 using std::string; 21 22 namespace webrtc { 23 namespace test { 24 25 const uint8_t kPayloadType = 95; 26 const int kOutputSizeMs = 10; 27 const int kInitSeed = 0x12345678; 28 const int kPacketLossTimeUnitMs = 10; 29 30 // Common validator for file names. 31 static bool ValidateFilename(const string& value, bool write) { 32 FILE* fid = write ? fopen(value.c_str(), "wb") : fopen(value.c_str(), "rb"); 33 if (fid == nullptr) 34 return false; 35 fclose(fid); 36 return true; 37 } 38 39 // Define switch for input file name. 40 static bool ValidateInFilename(const char* flagname, const string& value) { 41 if (!ValidateFilename(value, false)) { 42 printf("Invalid input filename."); 43 return false; 44 } 45 return true; 46 } 47 48 DEFINE_string( 49 in_filename, 50 ResourcePath("audio_coding/speech_mono_16kHz", "pcm"), 51 "Filename for input audio (specify sample rate with --input_sample_rate ," 52 "and channels with --channels)."); 53 54 static const bool in_filename_dummy = 55 RegisterFlagValidator(&FLAGS_in_filename, &ValidateInFilename); 56 57 // Define switch for sample rate. 58 static bool ValidateSampleRate(const char* flagname, int32_t value) { 59 if (value == 8000 || value == 16000 || value == 32000 || value == 48000) 60 return true; 61 printf("Invalid sample rate should be 8000, 16000, 32000 or 48000 Hz."); 62 return false; 63 } 64 65 DEFINE_int32(input_sample_rate, 16000, "Sample rate of input file in Hz."); 66 67 static const bool sample_rate_dummy = 68 RegisterFlagValidator(&FLAGS_input_sample_rate, &ValidateSampleRate); 69 70 // Define switch for channels. 71 static bool ValidateChannels(const char* flagname, int32_t value) { 72 if (value == 1) 73 return true; 74 printf("Invalid number of channels, current support only 1."); 75 return false; 76 } 77 78 DEFINE_int32(channels, 1, "Number of channels in input audio."); 79 80 static const bool channels_dummy = 81 RegisterFlagValidator(&FLAGS_channels, &ValidateChannels); 82 83 // Define switch for output file name. 84 static bool ValidateOutFilename(const char* flagname, const string& value) { 85 if (!ValidateFilename(value, true)) { 86 printf("Invalid output filename."); 87 return false; 88 } 89 return true; 90 } 91 92 DEFINE_string(out_filename, 93 OutputPath() + "neteq_quality_test_out.pcm", 94 "Name of output audio file."); 95 96 static const bool out_filename_dummy = 97 RegisterFlagValidator(&FLAGS_out_filename, &ValidateOutFilename); 98 99 // Define switch for packet loss rate. 100 static bool ValidatePacketLossRate(const char* /* flag_name */, int32_t value) { 101 if (value >= 0 && value <= 100) 102 return true; 103 printf("Invalid packet loss percentile, should be between 0 and 100."); 104 return false; 105 } 106 107 // Define switch for runtime. 108 static bool ValidateRuntime(const char* flagname, int32_t value) { 109 if (value > 0) 110 return true; 111 printf("Invalid runtime, should be greater than 0."); 112 return false; 113 } 114 115 DEFINE_int32(runtime_ms, 10000, "Simulated runtime (milliseconds)."); 116 117 static const bool runtime_dummy = 118 RegisterFlagValidator(&FLAGS_runtime_ms, &ValidateRuntime); 119 120 DEFINE_int32(packet_loss_rate, 10, "Percentile of packet loss."); 121 122 static const bool packet_loss_rate_dummy = 123 RegisterFlagValidator(&FLAGS_packet_loss_rate, &ValidatePacketLossRate); 124 125 // Define switch for random loss mode. 126 static bool ValidateRandomLossMode(const char* /* flag_name */, int32_t value) { 127 if (value >= 0 && value <= 2) 128 return true; 129 printf("Invalid random packet loss mode, should be between 0 and 2."); 130 return false; 131 } 132 133 DEFINE_int32(random_loss_mode, 1, 134 "Random loss mode: 0--no loss, 1--uniform loss, 2--Gilbert Elliot loss."); 135 static const bool random_loss_mode_dummy = 136 RegisterFlagValidator(&FLAGS_random_loss_mode, &ValidateRandomLossMode); 137 138 // Define switch for burst length. 139 static bool ValidateBurstLength(const char* /* flag_name */, int32_t value) { 140 if (value >= kPacketLossTimeUnitMs) 141 return true; 142 printf("Invalid burst length, should be greater than %d ms.", 143 kPacketLossTimeUnitMs); 144 return false; 145 } 146 147 DEFINE_int32(burst_length, 30, 148 "Burst length in milliseconds, only valid for Gilbert Elliot loss."); 149 150 static const bool burst_length_dummy = 151 RegisterFlagValidator(&FLAGS_burst_length, &ValidateBurstLength); 152 153 // Define switch for drift factor. 154 static bool ValidateDriftFactor(const char* /* flag_name */, double value) { 155 if (value > -0.1) 156 return true; 157 printf("Invalid drift factor, should be greater than -0.1."); 158 return false; 159 } 160 161 DEFINE_double(drift_factor, 0.0, "Time drift factor."); 162 163 static const bool drift_factor_dummy = 164 RegisterFlagValidator(&FLAGS_drift_factor, &ValidateDriftFactor); 165 166 // ProbTrans00Solver() is to calculate the transition probability from no-loss 167 // state to itself in a modified Gilbert Elliot packet loss model. The result is 168 // to achieve the target packet loss rate |loss_rate|, when a packet is not 169 // lost only if all |units| drawings within the duration of the packet result in 170 // no-loss. 171 static double ProbTrans00Solver(int units, double loss_rate, 172 double prob_trans_10) { 173 if (units == 1) 174 return prob_trans_10 / (1.0f - loss_rate) - prob_trans_10; 175 // 0 == prob_trans_00 ^ (units - 1) + (1 - loss_rate) / prob_trans_10 * 176 // prob_trans_00 - (1 - loss_rate) * (1 + 1 / prob_trans_10). 177 // There is a unique solution between 0.0 and 1.0, due to the monotonicity and 178 // an opposite sign at 0.0 and 1.0. 179 // For simplicity, we reformulate the equation as 180 // f(x) = x ^ (units - 1) + a x + b. 181 // Its derivative is 182 // f'(x) = (units - 1) x ^ (units - 2) + a. 183 // The derivative is strictly greater than 0 when x is between 0 and 1. 184 // We use Newton's method to solve the equation, iteration is 185 // x(k+1) = x(k) - f(x) / f'(x); 186 const double kPrecision = 0.001f; 187 const int kIterations = 100; 188 const double a = (1.0f - loss_rate) / prob_trans_10; 189 const double b = (loss_rate - 1.0f) * (1.0f + 1.0f / prob_trans_10); 190 double x = 0.0f; // Starting point; 191 double f = b; 192 double f_p; 193 int iter = 0; 194 while ((f >= kPrecision || f <= -kPrecision) && iter < kIterations) { 195 f_p = (units - 1.0f) * pow(x, units - 2) + a; 196 x -= f / f_p; 197 if (x > 1.0f) { 198 x = 1.0f; 199 } else if (x < 0.0f) { 200 x = 0.0f; 201 } 202 f = pow(x, units - 1) + a * x + b; 203 iter ++; 204 } 205 return x; 206 } 207 208 NetEqQualityTest::NetEqQualityTest(int block_duration_ms, 209 int in_sampling_khz, 210 int out_sampling_khz, 211 NetEqDecoder decoder_type) 212 : decoder_type_(decoder_type), 213 channels_(static_cast<size_t>(FLAGS_channels)), 214 decoded_time_ms_(0), 215 decodable_time_ms_(0), 216 drift_factor_(FLAGS_drift_factor), 217 packet_loss_rate_(FLAGS_packet_loss_rate), 218 block_duration_ms_(block_duration_ms), 219 in_sampling_khz_(in_sampling_khz), 220 out_sampling_khz_(out_sampling_khz), 221 in_size_samples_( 222 static_cast<size_t>(in_sampling_khz_ * block_duration_ms_)), 223 out_size_samples_(static_cast<size_t>(out_sampling_khz_ * kOutputSizeMs)), 224 payload_size_bytes_(0), 225 max_payload_bytes_(0), 226 in_file_(new ResampleInputAudioFile(FLAGS_in_filename, 227 FLAGS_input_sample_rate, 228 in_sampling_khz * 1000)), 229 rtp_generator_( 230 new RtpGenerator(in_sampling_khz_, 0, 0, decodable_time_ms_)), 231 total_payload_size_bytes_(0) { 232 const std::string out_filename = FLAGS_out_filename; 233 const std::string log_filename = out_filename + ".log"; 234 log_file_.open(log_filename.c_str(), std::ofstream::out); 235 RTC_CHECK(log_file_.is_open()); 236 237 if (out_filename.size() >= 4 && 238 out_filename.substr(out_filename.size() - 4) == ".wav") { 239 // Open a wav file. 240 output_.reset( 241 new webrtc::test::OutputWavFile(out_filename, 1000 * out_sampling_khz)); 242 } else { 243 // Open a pcm file. 244 output_.reset(new webrtc::test::OutputAudioFile(out_filename)); 245 } 246 247 NetEq::Config config; 248 config.sample_rate_hz = out_sampling_khz_ * 1000; 249 neteq_.reset(NetEq::Create(config)); 250 max_payload_bytes_ = in_size_samples_ * channels_ * sizeof(int16_t); 251 in_data_.reset(new int16_t[in_size_samples_ * channels_]); 252 payload_.reset(new uint8_t[max_payload_bytes_]); 253 out_data_.reset(new int16_t[out_size_samples_ * channels_]); 254 } 255 256 NetEqQualityTest::~NetEqQualityTest() { 257 log_file_.close(); 258 } 259 260 bool NoLoss::Lost() { 261 return false; 262 } 263 264 UniformLoss::UniformLoss(double loss_rate) 265 : loss_rate_(loss_rate) { 266 } 267 268 bool UniformLoss::Lost() { 269 int drop_this = rand(); 270 return (drop_this < loss_rate_ * RAND_MAX); 271 } 272 273 GilbertElliotLoss::GilbertElliotLoss(double prob_trans_11, double prob_trans_01) 274 : prob_trans_11_(prob_trans_11), 275 prob_trans_01_(prob_trans_01), 276 lost_last_(false), 277 uniform_loss_model_(new UniformLoss(0)) { 278 } 279 280 bool GilbertElliotLoss::Lost() { 281 // Simulate bursty channel (Gilbert model). 282 // (1st order) Markov chain model with memory of the previous/last 283 // packet state (lost or received). 284 if (lost_last_) { 285 // Previous packet was not received. 286 uniform_loss_model_->set_loss_rate(prob_trans_11_); 287 return lost_last_ = uniform_loss_model_->Lost(); 288 } else { 289 uniform_loss_model_->set_loss_rate(prob_trans_01_); 290 return lost_last_ = uniform_loss_model_->Lost(); 291 } 292 } 293 294 void NetEqQualityTest::SetUp() { 295 ASSERT_EQ(0, 296 neteq_->RegisterPayloadType(decoder_type_, "noname", kPayloadType)); 297 rtp_generator_->set_drift_factor(drift_factor_); 298 299 int units = block_duration_ms_ / kPacketLossTimeUnitMs; 300 switch (FLAGS_random_loss_mode) { 301 case 1: { 302 // |unit_loss_rate| is the packet loss rate for each unit time interval 303 // (kPacketLossTimeUnitMs). Since a packet loss event is generated if any 304 // of |block_duration_ms_ / kPacketLossTimeUnitMs| unit time intervals of 305 // a full packet duration is drawn with a loss, |unit_loss_rate| fulfills 306 // (1 - unit_loss_rate) ^ (block_duration_ms_ / kPacketLossTimeUnitMs) == 307 // 1 - packet_loss_rate. 308 double unit_loss_rate = (1.0f - pow(1.0f - 0.01f * packet_loss_rate_, 309 1.0f / units)); 310 loss_model_.reset(new UniformLoss(unit_loss_rate)); 311 break; 312 } 313 case 2: { 314 // |FLAGS_burst_length| should be integer times of kPacketLossTimeUnitMs. 315 ASSERT_EQ(0, FLAGS_burst_length % kPacketLossTimeUnitMs); 316 317 // We do not allow 100 percent packet loss in Gilbert Elliot model, which 318 // makes no sense. 319 ASSERT_GT(100, packet_loss_rate_); 320 321 // To guarantee the overall packet loss rate, transition probabilities 322 // need to satisfy: 323 // pi_0 * (1 - prob_trans_01_) ^ units + 324 // pi_1 * prob_trans_10_ ^ (units - 1) == 1 - loss_rate 325 // pi_0 = prob_trans_10 / (prob_trans_10 + prob_trans_01_) 326 // is the stationary state probability of no-loss 327 // pi_1 = prob_trans_01_ / (prob_trans_10 + prob_trans_01_) 328 // is the stationary state probability of loss 329 // After a derivation prob_trans_00 should satisfy: 330 // prob_trans_00 ^ (units - 1) = (loss_rate - 1) / prob_trans_10 * 331 // prob_trans_00 + (1 - loss_rate) * (1 + 1 / prob_trans_10). 332 double loss_rate = 0.01f * packet_loss_rate_; 333 double prob_trans_10 = 1.0f * kPacketLossTimeUnitMs / FLAGS_burst_length; 334 double prob_trans_00 = ProbTrans00Solver(units, loss_rate, prob_trans_10); 335 loss_model_.reset(new GilbertElliotLoss(1.0f - prob_trans_10, 336 1.0f - prob_trans_00)); 337 break; 338 } 339 default: { 340 loss_model_.reset(new NoLoss); 341 break; 342 } 343 } 344 345 // Make sure that the packet loss profile is same for all derived tests. 346 srand(kInitSeed); 347 } 348 349 std::ofstream& NetEqQualityTest::Log() { 350 return log_file_; 351 } 352 353 bool NetEqQualityTest::PacketLost() { 354 int cycles = block_duration_ms_ / kPacketLossTimeUnitMs; 355 356 // The loop is to make sure that codecs with different block lengths share the 357 // same packet loss profile. 358 bool lost = false; 359 for (int idx = 0; idx < cycles; idx ++) { 360 if (loss_model_->Lost()) { 361 // The packet will be lost if any of the drawings indicates a loss, but 362 // the loop has to go on to make sure that codecs with different block 363 // lengths keep the same pace. 364 lost = true; 365 } 366 } 367 return lost; 368 } 369 370 int NetEqQualityTest::Transmit() { 371 int packet_input_time_ms = 372 rtp_generator_->GetRtpHeader(kPayloadType, in_size_samples_, 373 &rtp_header_); 374 Log() << "Packet of size " 375 << payload_size_bytes_ 376 << " bytes, for frame at " 377 << packet_input_time_ms 378 << " ms "; 379 if (payload_size_bytes_ > 0) { 380 if (!PacketLost()) { 381 int ret = neteq_->InsertPacket( 382 rtp_header_, 383 rtc::ArrayView<const uint8_t>(payload_.get(), payload_size_bytes_), 384 packet_input_time_ms * in_sampling_khz_); 385 if (ret != NetEq::kOK) 386 return -1; 387 Log() << "was sent."; 388 } else { 389 Log() << "was lost."; 390 } 391 } 392 Log() << std::endl; 393 return packet_input_time_ms; 394 } 395 396 int NetEqQualityTest::DecodeBlock() { 397 size_t channels; 398 size_t samples; 399 int ret = neteq_->GetAudio(out_size_samples_ * channels_, &out_data_[0], 400 &samples, &channels, NULL); 401 402 if (ret != NetEq::kOK) { 403 return -1; 404 } else { 405 assert(channels == channels_); 406 assert(samples == static_cast<size_t>(kOutputSizeMs * out_sampling_khz_)); 407 RTC_CHECK(output_->WriteArray(out_data_.get(), samples * channels)); 408 return static_cast<int>(samples); 409 } 410 } 411 412 void NetEqQualityTest::Simulate() { 413 int audio_size_samples; 414 415 while (decoded_time_ms_ < FLAGS_runtime_ms) { 416 // Assume 10 packets in packets buffer. 417 while (decodable_time_ms_ - 10 * block_duration_ms_ < decoded_time_ms_) { 418 ASSERT_TRUE(in_file_->Read(in_size_samples_ * channels_, &in_data_[0])); 419 payload_size_bytes_ = EncodeBlock(&in_data_[0], 420 in_size_samples_, &payload_[0], 421 max_payload_bytes_); 422 total_payload_size_bytes_ += payload_size_bytes_; 423 decodable_time_ms_ = Transmit() + block_duration_ms_; 424 } 425 audio_size_samples = DecodeBlock(); 426 if (audio_size_samples > 0) { 427 decoded_time_ms_ += audio_size_samples / out_sampling_khz_; 428 } 429 } 430 Log() << "Average bit rate was " 431 << 8.0f * total_payload_size_bytes_ / FLAGS_runtime_ms 432 << " kbps" 433 << std::endl; 434 } 435 436 } // namespace test 437 } // namespace webrtc 438