Home | History | Annotate | Download | only in tools
      1 /*
      2  *  Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
      3  *
      4  *  Use of this source code is governed by a BSD-style license
      5  *  that can be found in the LICENSE file in the root of the source
      6  *  tree. An additional intellectual property rights grant can be found
      7  *  in the file PATENTS.  All contributing project authors may
      8  *  be found in the AUTHORS file in the root of the source tree.
      9  */
     10 
     11 #include <math.h>
     12 #include <stdio.h>
     13 #include "webrtc/base/checks.h"
     14 #include "webrtc/modules/audio_coding/neteq/tools/neteq_quality_test.h"
     15 #include "webrtc/modules/audio_coding/neteq/tools/output_audio_file.h"
     16 #include "webrtc/modules/audio_coding/neteq/tools/output_wav_file.h"
     17 #include "webrtc/modules/audio_coding/neteq/tools/resample_input_audio_file.h"
     18 #include "webrtc/test/testsupport/fileutils.h"
     19 
     20 using std::string;
     21 
     22 namespace webrtc {
     23 namespace test {
     24 
// RTP payload type used both when registering the decoder with NetEq and when
// generating RTP headers for the simulated packets.
const uint8_t kPayloadType = 95;
// Length, in milliseconds, of the audio block requested from NetEq on each
// decode call.
const int kOutputSizeMs = 10;
// Seed handed to srand() so that every run draws the same loss pattern.
const int kInitSeed = 0x12345678;
// Time unit, in milliseconds, of a single packet-loss drawing. A packet spans
// block_duration_ms / kPacketLossTimeUnitMs drawings.
const int kPacketLossTimeUnitMs = 10;
     29 
     30 // Common validator for file names.
     31 static bool ValidateFilename(const string& value, bool write) {
     32   FILE* fid = write ? fopen(value.c_str(), "wb") : fopen(value.c_str(), "rb");
     33   if (fid == nullptr)
     34     return false;
     35   fclose(fid);
     36   return true;
     37 }
     38 
     39 // Define switch for input file name.
     40 static bool ValidateInFilename(const char* flagname, const string& value) {
     41   if (!ValidateFilename(value, false)) {
     42     printf("Invalid input filename.");
     43     return false;
     44   }
     45   return true;
     46 }
     47 
     48 DEFINE_string(
     49     in_filename,
     50     ResourcePath("audio_coding/speech_mono_16kHz", "pcm"),
     51     "Filename for input audio (specify sample rate with --input_sample_rate ,"
     52     "and channels with --channels).");
     53 
     54 static const bool in_filename_dummy =
     55     RegisterFlagValidator(&FLAGS_in_filename, &ValidateInFilename);
     56 
     57 // Define switch for sample rate.
     58 static bool ValidateSampleRate(const char* flagname, int32_t value) {
     59   if (value == 8000 || value == 16000 || value == 32000 || value == 48000)
     60     return true;
     61   printf("Invalid sample rate should be 8000, 16000, 32000 or 48000 Hz.");
     62   return false;
     63 }
     64 
// Flag --input_sample_rate: sample rate of the input file in Hz (default
// 16000).
DEFINE_int32(input_sample_rate, 16000, "Sample rate of input file in Hz.");

// Attaches ValidateSampleRate to the flag. The bool result is stored only so
// that the registration call can be made at namespace scope.
static const bool sample_rate_dummy =
    RegisterFlagValidator(&FLAGS_input_sample_rate, &ValidateSampleRate);
     69 
     70 // Define switch for channels.
     71 static bool ValidateChannels(const char* flagname, int32_t value) {
     72   if (value == 1)
     73     return true;
     74   printf("Invalid number of channels, current support only 1.");
     75   return false;
     76 }
     77 
// Flag --channels: number of channels in the input audio (default 1).
DEFINE_int32(channels, 1, "Number of channels in input audio.");

// Attaches ValidateChannels to the flag. The bool result is stored only so
// that the registration call can be made at namespace scope.
static const bool channels_dummy =
    RegisterFlagValidator(&FLAGS_channels, &ValidateChannels);
     82 
     83 // Define switch for output file name.
     84 static bool ValidateOutFilename(const char* flagname, const string& value) {
     85   if (!ValidateFilename(value, true)) {
     86     printf("Invalid output filename.");
     87     return false;
     88   }
     89   return true;
     90 }
     91 
// Flag --out_filename: name of the output audio file. The default lives under
// the directory returned by OutputPath().
DEFINE_string(out_filename,
              OutputPath() + "neteq_quality_test_out.pcm",
              "Name of output audio file.");

// Attaches ValidateOutFilename to the flag. The bool result is stored only so
// that the registration call can be made at namespace scope.
static const bool out_filename_dummy =
    RegisterFlagValidator(&FLAGS_out_filename, &ValidateOutFilename);
     98 
     99 // Define switch for packet loss rate.
    100 static bool ValidatePacketLossRate(const char* /* flag_name */, int32_t value) {
    101   if (value >= 0 && value <= 100)
    102     return true;
    103   printf("Invalid packet loss percentile, should be between 0 and 100.");
    104   return false;
    105 }
    106 
    107 // Define switch for runtime.
    108 static bool ValidateRuntime(const char* flagname, int32_t value) {
    109   if (value > 0)
    110     return true;
    111   printf("Invalid runtime, should be greater than 0.");
    112   return false;
    113 }
    114 
// Flag --runtime_ms: simulated runtime in milliseconds (default 10000).
DEFINE_int32(runtime_ms, 10000, "Simulated runtime (milliseconds).");

// Attaches ValidateRuntime to the flag. The bool result is stored only so
// that the registration call can be made at namespace scope.
static const bool runtime_dummy =
    RegisterFlagValidator(&FLAGS_runtime_ms, &ValidateRuntime);
    119 
// Flag --packet_loss_rate: target packet loss, in percent (default 10). The
// validator accepts values from 0 to 100.
DEFINE_int32(packet_loss_rate, 10, "Percentile of packet loss.");

// Attaches ValidatePacketLossRate to the flag. The bool result is stored only
// so that the registration call can be made at namespace scope.
static const bool packet_loss_rate_dummy =
    RegisterFlagValidator(&FLAGS_packet_loss_rate, &ValidatePacketLossRate);
    124 
    125 // Define switch for random loss mode.
    126 static bool ValidateRandomLossMode(const char* /* flag_name */, int32_t value) {
    127   if (value >= 0 && value <= 2)
    128     return true;
    129   printf("Invalid random packet loss mode, should be between 0 and 2.");
    130   return false;
    131 }
    132 
// Flag --random_loss_mode: selects the loss model built in SetUp()
// (default 1, uniform loss).
DEFINE_int32(random_loss_mode, 1,
    "Random loss mode: 0--no loss, 1--uniform loss, 2--Gilbert Elliot loss.");
// Attaches ValidateRandomLossMode to the flag. The bool result is stored only
// so that the registration call can be made at namespace scope.
static const bool random_loss_mode_dummy =
    RegisterFlagValidator(&FLAGS_random_loss_mode, &ValidateRandomLossMode);
    137 
    138 // Define switch for burst length.
    139 static bool ValidateBurstLength(const char* /* flag_name */, int32_t value) {
    140   if (value >= kPacketLossTimeUnitMs)
    141     return true;
    142   printf("Invalid burst length, should be greater than %d ms.",
    143          kPacketLossTimeUnitMs);
    144   return false;
    145 }
    146 
// Flag --burst_length: burst length in milliseconds (default 30). It is read
// only when the Gilbert Elliot loss model (mode 2) is selected in SetUp().
DEFINE_int32(burst_length, 30,
    "Burst length in milliseconds, only valid for Gilbert Elliot loss.");

// Attaches ValidateBurstLength to the flag. The bool result is stored only so
// that the registration call can be made at namespace scope.
static const bool burst_length_dummy =
    RegisterFlagValidator(&FLAGS_burst_length, &ValidateBurstLength);
    152 
    153 // Define switch for drift factor.
    154 static bool ValidateDriftFactor(const char* /* flag_name */, double value) {
    155   if (value > -0.1)
    156     return true;
    157   printf("Invalid drift factor, should be greater than -0.1.");
    158   return false;
    159 }
    160 
// Flag --drift_factor: time drift factor forwarded to the RTP generator
// (default 0.0).
DEFINE_double(drift_factor, 0.0, "Time drift factor.");

// Attaches ValidateDriftFactor to the flag. The bool result is stored only so
// that the registration call can be made at namespace scope.
static const bool drift_factor_dummy =
    RegisterFlagValidator(&FLAGS_drift_factor, &ValidateDriftFactor);
    165 
    166 // ProbTrans00Solver() is to calculate the transition probability from no-loss
    167 // state to itself in a modified Gilbert Elliot packet loss model. The result is
    168 // to achieve the target packet loss rate |loss_rate|, when a packet is not
    169 // lost only if all |units| drawings within the duration of the packet result in
    170 // no-loss.
    171 static double ProbTrans00Solver(int units, double loss_rate,
    172                                 double prob_trans_10) {
    173   if (units == 1)
    174     return prob_trans_10 / (1.0f - loss_rate) - prob_trans_10;
    175 // 0 == prob_trans_00 ^ (units - 1) + (1 - loss_rate) / prob_trans_10 *
    176 //     prob_trans_00 - (1 - loss_rate) * (1 + 1 / prob_trans_10).
    177 // There is a unique solution between 0.0 and 1.0, due to the monotonicity and
    178 // an opposite sign at 0.0 and 1.0.
    179 // For simplicity, we reformulate the equation as
    180 //     f(x) = x ^ (units - 1) + a x + b.
    181 // Its derivative is
    182 //     f'(x) = (units - 1) x ^ (units - 2) + a.
    183 // The derivative is strictly greater than 0 when x is between 0 and 1.
    184 // We use Newton's method to solve the equation, iteration is
    185 //     x(k+1) = x(k) - f(x) / f'(x);
    186   const double kPrecision = 0.001f;
    187   const int kIterations = 100;
    188   const double a = (1.0f - loss_rate) / prob_trans_10;
    189   const double b = (loss_rate - 1.0f) * (1.0f + 1.0f / prob_trans_10);
    190   double x = 0.0f;  // Starting point;
    191   double f = b;
    192   double f_p;
    193   int iter = 0;
    194   while ((f >= kPrecision || f <= -kPrecision) && iter < kIterations) {
    195     f_p = (units - 1.0f) * pow(x, units - 2) + a;
    196     x -= f / f_p;
    197     if (x > 1.0f) {
    198       x = 1.0f;
    199     } else if (x < 0.0f) {
    200       x = 0.0f;
    201     }
    202     f = pow(x, units - 1) + a * x + b;
    203     iter ++;
    204   }
    205   return x;
    206 }
    207 
    208 NetEqQualityTest::NetEqQualityTest(int block_duration_ms,
    209                                    int in_sampling_khz,
    210                                    int out_sampling_khz,
    211                                    NetEqDecoder decoder_type)
    212     : decoder_type_(decoder_type),
    213       channels_(static_cast<size_t>(FLAGS_channels)),
    214       decoded_time_ms_(0),
    215       decodable_time_ms_(0),
    216       drift_factor_(FLAGS_drift_factor),
    217       packet_loss_rate_(FLAGS_packet_loss_rate),
    218       block_duration_ms_(block_duration_ms),
    219       in_sampling_khz_(in_sampling_khz),
    220       out_sampling_khz_(out_sampling_khz),
    221       in_size_samples_(
    222           static_cast<size_t>(in_sampling_khz_ * block_duration_ms_)),
    223       out_size_samples_(static_cast<size_t>(out_sampling_khz_ * kOutputSizeMs)),
    224       payload_size_bytes_(0),
    225       max_payload_bytes_(0),
    226       in_file_(new ResampleInputAudioFile(FLAGS_in_filename,
    227                                           FLAGS_input_sample_rate,
    228                                           in_sampling_khz * 1000)),
    229       rtp_generator_(
    230           new RtpGenerator(in_sampling_khz_, 0, 0, decodable_time_ms_)),
    231       total_payload_size_bytes_(0) {
    232   const std::string out_filename = FLAGS_out_filename;
    233   const std::string log_filename = out_filename + ".log";
    234   log_file_.open(log_filename.c_str(), std::ofstream::out);
    235   RTC_CHECK(log_file_.is_open());
    236 
    237   if (out_filename.size() >= 4 &&
    238       out_filename.substr(out_filename.size() - 4) == ".wav") {
    239     // Open a wav file.
    240     output_.reset(
    241         new webrtc::test::OutputWavFile(out_filename, 1000 * out_sampling_khz));
    242   } else {
    243     // Open a pcm file.
    244     output_.reset(new webrtc::test::OutputAudioFile(out_filename));
    245   }
    246 
    247   NetEq::Config config;
    248   config.sample_rate_hz = out_sampling_khz_ * 1000;
    249   neteq_.reset(NetEq::Create(config));
    250   max_payload_bytes_ = in_size_samples_ * channels_ * sizeof(int16_t);
    251   in_data_.reset(new int16_t[in_size_samples_ * channels_]);
    252   payload_.reset(new uint8_t[max_payload_bytes_]);
    253   out_data_.reset(new int16_t[out_size_samples_ * channels_]);
    254 }
    255 
// Closes the log file that the constructor opened.
NetEqQualityTest::~NetEqQualityTest() {
  log_file_.close();
}
    259 
// Loss model that never drops anything: every drawing reports "not lost".
bool NoLoss::Lost() {
  return false;
}
    263 
// Creates a uniform loss model. |loss_rate| is stored as the per-drawing loss
// probability that Lost() compares against rand().
UniformLoss::UniformLoss(double loss_rate)
    : loss_rate_(loss_rate) {
}
    267 
    268 bool UniformLoss::Lost() {
    269   int drop_this = rand();
    270   return (drop_this < loss_rate_ * RAND_MAX);
    271 }
    272 
// Creates a two-state (Gilbert Elliot) loss model.
// |prob_trans_11| is the loss rate applied after a lost drawing and
// |prob_trans_01| the loss rate applied after a non-lost drawing. The model
// starts in the non-lost state and owns a UniformLoss instance, created with
// rate 0, that Lost() reconfigures before each drawing.
GilbertElliotLoss::GilbertElliotLoss(double prob_trans_11, double prob_trans_01)
    : prob_trans_11_(prob_trans_11),
      prob_trans_01_(prob_trans_01),
      lost_last_(false),
      uniform_loss_model_(new UniformLoss(0)) {
}
    279 
    280 bool GilbertElliotLoss::Lost() {
    281   // Simulate bursty channel (Gilbert model).
    282   // (1st order) Markov chain model with memory of the previous/last
    283   // packet state (lost or received).
    284   if (lost_last_) {
    285     // Previous packet was not received.
    286     uniform_loss_model_->set_loss_rate(prob_trans_11_);
    287     return lost_last_ = uniform_loss_model_->Lost();
    288   } else {
    289     uniform_loss_model_->set_loss_rate(prob_trans_01_);
    290     return lost_last_ = uniform_loss_model_->Lost();
    291   }
    292 }
    293 
    294 void NetEqQualityTest::SetUp() {
    295   ASSERT_EQ(0,
    296             neteq_->RegisterPayloadType(decoder_type_, "noname", kPayloadType));
    297   rtp_generator_->set_drift_factor(drift_factor_);
    298 
    299   int units = block_duration_ms_ / kPacketLossTimeUnitMs;
    300   switch (FLAGS_random_loss_mode) {
    301     case 1: {
    302       // |unit_loss_rate| is the packet loss rate for each unit time interval
    303       // (kPacketLossTimeUnitMs). Since a packet loss event is generated if any
    304       // of |block_duration_ms_ / kPacketLossTimeUnitMs| unit time intervals of
    305       // a full packet duration is drawn with a loss, |unit_loss_rate| fulfills
    306       // (1 - unit_loss_rate) ^ (block_duration_ms_ / kPacketLossTimeUnitMs) ==
    307       // 1 - packet_loss_rate.
    308       double unit_loss_rate = (1.0f - pow(1.0f - 0.01f * packet_loss_rate_,
    309           1.0f / units));
    310       loss_model_.reset(new UniformLoss(unit_loss_rate));
    311       break;
    312     }
    313     case 2: {
    314       // |FLAGS_burst_length| should be integer times of kPacketLossTimeUnitMs.
    315       ASSERT_EQ(0, FLAGS_burst_length % kPacketLossTimeUnitMs);
    316 
    317       // We do not allow 100 percent packet loss in Gilbert Elliot model, which
    318       // makes no sense.
    319       ASSERT_GT(100, packet_loss_rate_);
    320 
    321       // To guarantee the overall packet loss rate, transition probabilities
    322       // need to satisfy:
    323       // pi_0 * (1 - prob_trans_01_) ^ units +
    324       //     pi_1 * prob_trans_10_ ^ (units - 1) == 1 - loss_rate
    325       // pi_0 = prob_trans_10 / (prob_trans_10 + prob_trans_01_)
    326       //     is the stationary state probability of no-loss
    327       // pi_1 = prob_trans_01_ / (prob_trans_10 + prob_trans_01_)
    328       //     is the stationary state probability of loss
    329       // After a derivation prob_trans_00 should satisfy:
    330       // prob_trans_00 ^ (units - 1) = (loss_rate - 1) / prob_trans_10 *
    331       //     prob_trans_00 + (1 - loss_rate) * (1 + 1 / prob_trans_10).
    332       double loss_rate = 0.01f * packet_loss_rate_;
    333       double prob_trans_10 = 1.0f * kPacketLossTimeUnitMs / FLAGS_burst_length;
    334       double prob_trans_00 = ProbTrans00Solver(units, loss_rate, prob_trans_10);
    335       loss_model_.reset(new GilbertElliotLoss(1.0f - prob_trans_10,
    336                                               1.0f - prob_trans_00));
    337       break;
    338     }
    339     default: {
    340       loss_model_.reset(new NoLoss);
    341       break;
    342     }
    343   }
    344 
    345   // Make sure that the packet loss profile is same for all derived tests.
    346   srand(kInitSeed);
    347 }
    348 
// Returns the log stream that the constructor opened
// ("<out_filename>.log").
std::ofstream& NetEqQualityTest::Log() {
  return log_file_;
}
    352 
    353 bool NetEqQualityTest::PacketLost() {
    354   int cycles = block_duration_ms_ / kPacketLossTimeUnitMs;
    355 
    356   // The loop is to make sure that codecs with different block lengths share the
    357   // same packet loss profile.
    358   bool lost = false;
    359   for (int idx = 0; idx < cycles; idx ++) {
    360     if (loss_model_->Lost()) {
    361       // The packet will be lost if any of the drawings indicates a loss, but
    362       // the loop has to go on to make sure that codecs with different block
    363       // lengths keep the same pace.
    364       lost = true;
    365     }
    366   }
    367   return lost;
    368 }
    369 
    370 int NetEqQualityTest::Transmit() {
    371   int packet_input_time_ms =
    372       rtp_generator_->GetRtpHeader(kPayloadType, in_size_samples_,
    373                                    &rtp_header_);
    374   Log() << "Packet of size "
    375         << payload_size_bytes_
    376         << " bytes, for frame at "
    377         << packet_input_time_ms
    378         << " ms ";
    379   if (payload_size_bytes_ > 0) {
    380     if (!PacketLost()) {
    381       int ret = neteq_->InsertPacket(
    382           rtp_header_,
    383           rtc::ArrayView<const uint8_t>(payload_.get(), payload_size_bytes_),
    384           packet_input_time_ms * in_sampling_khz_);
    385       if (ret != NetEq::kOK)
    386         return -1;
    387       Log() << "was sent.";
    388     } else {
    389       Log() << "was lost.";
    390     }
    391   }
    392   Log() << std::endl;
    393   return packet_input_time_ms;
    394 }
    395 
    396 int NetEqQualityTest::DecodeBlock() {
    397   size_t channels;
    398   size_t samples;
    399   int ret = neteq_->GetAudio(out_size_samples_ * channels_, &out_data_[0],
    400                              &samples, &channels, NULL);
    401 
    402   if (ret != NetEq::kOK) {
    403     return -1;
    404   } else {
    405     assert(channels == channels_);
    406     assert(samples == static_cast<size_t>(kOutputSizeMs * out_sampling_khz_));
    407     RTC_CHECK(output_->WriteArray(out_data_.get(), samples * channels));
    408     return static_cast<int>(samples);
    409   }
    410 }
    411 
    412 void NetEqQualityTest::Simulate() {
    413   int audio_size_samples;
    414 
    415   while (decoded_time_ms_ < FLAGS_runtime_ms) {
    416     // Assume 10 packets in packets buffer.
    417     while (decodable_time_ms_ - 10 * block_duration_ms_ < decoded_time_ms_) {
    418       ASSERT_TRUE(in_file_->Read(in_size_samples_ * channels_, &in_data_[0]));
    419       payload_size_bytes_ = EncodeBlock(&in_data_[0],
    420                                         in_size_samples_, &payload_[0],
    421                                         max_payload_bytes_);
    422       total_payload_size_bytes_ += payload_size_bytes_;
    423       decodable_time_ms_ = Transmit() + block_duration_ms_;
    424     }
    425     audio_size_samples = DecodeBlock();
    426     if (audio_size_samples > 0) {
    427       decoded_time_ms_ += audio_size_samples / out_sampling_khz_;
    428     }
    429   }
    430   Log() << "Average bit rate was "
    431         << 8.0f * total_payload_size_bytes_ / FLAGS_runtime_ms
    432         << " kbps"
    433         << std::endl;
    434 }
    435 
    436 }  // namespace test
    437 }  // namespace webrtc
    438