Home | History | Annotate | Download | only in brillo
      1 # Copyright (c) 2016 The Chromium Authors. All rights reserved.
      2 # Use of this source code is governed by a BSD-style license that can be
      3 # found in the LICENSE file.
      4 
      5 """Server side audio utilities functions for Brillo."""
      6 
      7 import contextlib
      8 import logging
      9 import numpy
     10 import os
     11 import struct
     12 import subprocess
     13 import tempfile
     14 import wave
     15 
     16 from autotest_lib.client.common_lib import error
     17 
     18 
     19 _BITS_PER_BYTE=8
     20 
     21 # Thresholds used when comparing files.
     22 #
     23 # The frequency threshold used when comparing files. The frequency of the
     24 # recorded audio has to be within _FREQUENCY_THRESHOLD percent of the frequency
     25 # of the original audio.
     26 _FREQUENCY_THRESHOLD = 0.01
     27 # Noise threshold controls how much noise is allowed as a fraction of the
     28 # magnitude of the peak frequency after taking an FFT. The power of all the
     29 # other frequencies in the signal should be within _FFT_NOISE_THRESHOLD percent
     30 # of the power of the main frequency.
     31 _FFT_NOISE_THRESHOLD = 0.05
     32 
     33 # Command used to encode audio. If you want to test with something different,
     34 # this should be changed.
     35 _ENCODING_CMD = 'sox'
     36 
     37 
     38 def extract_wav_frames(wave_file):
     39     """Extract all frames from a WAV file.
     40 
     41     @param wave_file: A Wave_read object representing a WAV file opened for
     42                       reading.
     43 
     44     @return: A list containing the frames in the WAV file.
     45     """
     46     num_frames = wave_file.getnframes()
     47     sample_width = wave_file.getsampwidth()
     48     if sample_width == 1:
     49         fmt = '%iB'  # Read 1 byte.
     50     elif sample_width == 2:
     51         fmt = '%ih'  # Read 2 bytes.
     52     elif sample_width == 4:
     53         fmt = '%ii'  # Read 4 bytes.
     54     else:
     55         raise ValueError('Unsupported sample width')
     56     frames =  list(struct.unpack(fmt % num_frames * wave_file.getnchannels(),
     57                                  wave_file.readframes(num_frames)))
     58 
     59     # Since 8-bit PCM is unsigned with an offset of 128, we subtract the offset
     60     # to make it signed since the rest of the code assumes signed numbers.
     61     if sample_width == 1:
     62         frames = [val - 128 for val in frames]
     63 
     64     return frames
     65 
     66 
     67 def check_wav_file(filename, num_channels=None, sample_rate=None,
     68                    sample_width=None):
     69     """Checks a WAV file and returns its peak PCM values.
     70 
     71     @param filename: Input WAV file to analyze.
     72     @param num_channels: Number of channels to expect (None to not check).
     73     @param sample_rate: Sample rate to expect (None to not check).
     74     @param sample_width: Sample width to expect (None to not check).
     75 
     76     @return A list of the absolute maximum PCM values for each channel in the
     77             WAV file.
     78 
     79     @raise ValueError: Failed to process the WAV file or validate an attribute.
     80     """
     81     chk_file = None
     82     try:
     83         chk_file = wave.open(filename, 'r')
     84         if num_channels is not None and chk_file.getnchannels() != num_channels:
     85             raise ValueError('Expected %d channels but got %d instead.',
     86                              num_channels, chk_file.getnchannels())
     87         if sample_rate is not None and chk_file.getframerate() != sample_rate:
     88             raise ValueError('Expected sample rate %d but got %d instead.',
     89                              sample_rate, chk_file.getframerate())
     90         if sample_width is not None and chk_file.getsampwidth() != sample_width:
     91             raise ValueError('Expected sample width %d but got %d instead.',
     92                              sample_width, chk_file.getsampwidth())
     93         frames = extract_wav_frames(chk_file)
     94     except wave.Error as e:
     95         raise ValueError('Error processing WAV file: %s' % e)
     96     finally:
     97         if chk_file is not None:
     98             chk_file.close()
     99 
    100     peaks = []
    101     for i in range(chk_file.getnchannels()):
    102         peaks.append(max(map(abs, frames[i::chk_file.getnchannels()])))
    103     return peaks;
    104 
    105 
    106 def generate_sine_file(host, num_channels, sample_rate, sample_width,
    107                        duration_secs, sine_frequency, temp_dir,
    108                        file_format='wav'):
    109     """Generate a sine file and push it to the DUT.
    110 
    111     @param host: An object representing the DUT.
    112     @param num_channels: Number of channels to use.
    113     @param sample_rate: Sample rate to use for sine wave generation.
    114     @param sample_width: Sample width to use for sine wave generation.
    115     @param duration_secs: Duration in seconds to generate sine wave for.
    116     @param sine_frequency: Frequency to generate sine wave at.
    117     @param temp_dir: A temporary directory on the host.
    118     @param file_format: A string representing the encoding for the audio file.
    119 
    120     @return A tuple of the filename on the server and the DUT.
    121     """;
    122     _, local_filename = tempfile.mkstemp(
    123         prefix='sine-', suffix='.' + file_format, dir=temp_dir)
    124     if sample_width == 1:
    125         byte_format = '-e unsigned'
    126     else:
    127         byte_format = '-e signed'
    128     gen_file_cmd = ('sox -n -t wav -c %d %s -b %d -r %d %s synth %d sine %d '
    129                     'vol 0.9' % (num_channels, byte_format,
    130                                  sample_width * _BITS_PER_BYTE, sample_rate,
    131                                  local_filename, duration_secs, sine_frequency))
    132     logging.info('Command to generate sine wave: %s', gen_file_cmd)
    133     subprocess.call(gen_file_cmd, shell=True)
    134     if file_format != 'wav':
    135         # Convert the file to the appropriate format.
    136         logging.info('Converting file to %s', file_format)
    137         _, local_encoded_filename = tempfile.mkstemp(
    138                 prefix='sine-', suffix='.' + file_format, dir=temp_dir)
    139         cvt_file_cmd = '%s %s %s' % (_ENCODING_CMD, local_filename,
    140                                      local_encoded_filename)
    141         logging.info('Command to convert file: %s', cvt_file_cmd)
    142         subprocess.call(cvt_file_cmd, shell=True)
    143     else:
    144         local_encoded_filename = local_filename
    145     dut_tmp_dir = '/data'
    146     remote_filename = os.path.join(dut_tmp_dir, 'sine.' + file_format)
    147     logging.info('Send file to DUT.')
    148     # TODO(ralphnathan): Find a better place to put this file once the SELinux
    149     # issues are resolved.
    150     logging.info('remote_filename %s', remote_filename)
    151     host.send_file(local_encoded_filename, remote_filename)
    152     return local_filename, remote_filename
    153 
    154 
    155 def _is_outside_frequency_threshold(freq_reference, freq_rec):
    156     """Compares the frequency of the recorded audio with the reference audio.
    157 
    158     This function checks to see if the frequencies corresponding to the peak
    159     FFT values are similiar meaning that the dominant frequency in the audio
    160     signal is the same for the recorded audio as that in the audio played.
    161 
    162     @param req_reference: The dominant frequency in the reference audio file.
    163     @param freq_rec: The dominant frequency in the recorded audio file.
    164 
    165     @return: True is freq_rec is with _FREQUENCY_THRESHOLD percent of
    166               freq_reference.
    167     """
    168     ratio = float(freq_rec) / freq_reference
    169     if ratio > 1 + _FREQUENCY_THRESHOLD or ratio < 1 - _FREQUENCY_THRESHOLD:
    170         return True
    171     return False
    172 
    173 
    174 def _compare_frames(reference_file_frames, rec_file_frames, num_channels,
    175                     sample_rate):
    176     """Compares audio frames from the reference file and the recorded file.
    177 
    178     This method checks for two things:
    179       1. That the main frequency is the same in both the files. This is done
    180          using the FFT and observing the frequency corresponding to the
    181          peak.
    182       2. That there is no other dominant frequency in the recorded file.
    183          This is done by sweeping the frequency domain and checking that the
    184          frequency is always less than _FFT_NOISE_THRESHOLD percentage of
    185          the peak.
    186 
    187     The key assumption here is that the reference audio file contains only
    188     one frequency.
    189 
    190     @param reference_file_frames: Audio frames from the reference file.
    191     @param rec_file_frames: Audio frames from the recorded file.
    192     @param num_channels: Number of channels in the files.
    193     @param sample_rate: Sample rate of the files.
    194 
    195     @raise error.TestFail: The frequency of the recorded signal doesn't
    196                            match that of the reference signal.
    197     @raise error.TestFail: There is too much noise in the recorded signal.
    198     """
    199     for channel in range(num_channels):
    200         reference_data = reference_file_frames[channel::num_channels]
    201         rec_data = rec_file_frames[channel::num_channels]
    202 
    203         # Get fft and frequencies corresponding to the fft values.
    204         fft_reference = numpy.fft.rfft(reference_data)
    205         fft_rec = numpy.fft.rfft(rec_data)
    206         fft_freqs_reference = numpy.fft.rfftfreq(len(reference_data),
    207                                                  1.0 / sample_rate)
    208         fft_freqs_rec = numpy.fft.rfftfreq(len(rec_data), 1.0 / sample_rate)
    209 
    210         # Get frequency at highest peak.
    211         freq_reference = fft_freqs_reference[
    212                 numpy.argmax(numpy.abs(fft_reference))]
    213         abs_fft_rec = numpy.abs(fft_rec)
    214         freq_rec = fft_freqs_rec[numpy.argmax(abs_fft_rec)]
    215 
    216         # Compare the two frequencies.
    217         logging.info('Golden frequency of channel %i is %f', channel,
    218                      freq_reference)
    219         logging.info('Recorded frequency of channel %i is  %f', channel,
    220                      freq_rec)
    221         if _is_outside_frequency_threshold(freq_reference, freq_rec):
    222             raise error.TestFail('The recorded audio frequency does not match '
    223                                  'that of the audio played.')
    224 
    225         # Check for noise in the frequency domain.
    226         fft_rec_peak_val = numpy.max(abs_fft_rec)
    227         noise_detected = False
    228         for fft_index, fft_val in enumerate(abs_fft_rec):
    229             if _is_outside_frequency_threshold(freq_reference, freq_rec):
    230                 # If the frequency exceeds _FFT_NOISE_THRESHOLD, then fail.
    231                 if fft_val > _FFT_NOISE_THRESHOLD * fft_rec_peak_val:
    232                     logging.warning('Unexpected frequency peak detected at %f '
    233                                     'Hz.', fft_freqs_rec[fft_index])
    234                     noise_detected = True
    235 
    236         if noise_detected:
    237             raise error.TestFail('Signal is noiser than expected.')
    238 
    239 
    240 def compare_file(reference_audio_filename, test_audio_filename):
    241     """Compares the recorded audio file to the reference audio file.
    242 
    243     @param reference_audio_filename : Reference audio file containing the
    244                                       reference signal.
    245     @param test_audio_filename: Audio file containing audio captured from
    246                                 the test.
    247     """
    248     with contextlib.closing(wave.open(reference_audio_filename,
    249                                       'rb')) as reference_file:
    250         with contextlib.closing(wave.open(test_audio_filename,
    251                                           'rb')) as rec_file:
    252             # Extract data from files.
    253             reference_file_frames = extract_wav_frames(reference_file)
    254             rec_file_frames = extract_wav_frames(rec_file)
    255 
    256             num_channels = reference_file.getnchannels()
    257             _compare_frames(reference_file_frames, rec_file_frames,
    258                             reference_file.getnchannels(),
    259                             reference_file.getframerate())
    260