# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """Model definitions for simple speech recognition.
     16 
     17 """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import hashlib
import math
import os.path
import random
import re
import sys
import tarfile

import numpy as np
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
from tensorflow.python.ops import io_ops
from tensorflow.python.platform import gfile
from tensorflow.python.util import compat

MAX_NUM_WAVS_PER_CLASS = 2**27 - 1  # ~134M
SILENCE_LABEL = '_silence_'
SILENCE_INDEX = 0
UNKNOWN_WORD_LABEL = '_unknown_'
UNKNOWN_WORD_INDEX = 1
BACKGROUND_NOISE_DIR_NAME = '_background_noise_'
RANDOM_SEED = 59185

def prepare_words_list(wanted_words):
  """Prepends common tokens to the custom word list.

  Args:
    wanted_words: List of strings containing the custom words.

  Returns:
    List with the standard silence and unknown tokens added.
  """
  return [SILENCE_LABEL, UNKNOWN_WORD_LABEL] + wanted_words
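
# Illustrative example (hypothetical word list): prepare_words_list(['yes',
# 'no']) returns ['_silence_', '_unknown_', 'yes', 'no'], which is why wanted
# words are assigned indexes starting at 2 elsewhere in this file.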


def which_set(filename, validation_percentage, testing_percentage):
  """Determines which data partition the file should belong to.

  We want to keep files in the same training, validation, or testing sets even
  if new ones are added over time. This makes it less likely that testing
  samples will accidentally be reused in training when long runs are
  restarted, for example. To keep this stability, a hash of the filename is
  taken and used to determine which set it should belong to. This
  determination only depends on the name and the set proportions, so it won't
  change as other files are added.

  It's also useful to associate particular files as related (for example,
  words spoken by the same person), so anything after '_nohash_' in a filename
  is ignored for set determination. This ensures that 'bobby_nohash_0.wav' and
  'bobby_nohash_1.wav' are always in the same set, for example.

  Args:
    filename: File path of the data sample.
    validation_percentage: How much of the data set to use for validation.
    testing_percentage: How much of the data set to use for testing.

  Returns:
    String, one of 'training', 'validation', or 'testing'.
  """
  base_name = os.path.basename(filename)
  # We want to ignore anything after '_nohash_' in the file name when
  # deciding which set to put a wav in, so the data set creator has a way of
  # grouping wavs that are close variations of each other.
  hash_name = re.sub(r'_nohash_.*$', '', base_name)
  # This looks a bit magical, but we need to decide whether this file should
  # go into the training, testing, or validation sets, and we want to keep
  # existing files in the same set even if more files are subsequently
  # added.
  # To do that, we need a stable way of deciding based on just the file name
  # itself, so we do a hash of that and then use that to generate a
  # probability value that we use to assign it.
  hash_name_hashed = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest()
  percentage_hash = ((int(hash_name_hashed, 16) %
                      (MAX_NUM_WAVS_PER_CLASS + 1)) *
                     (100.0 / MAX_NUM_WAVS_PER_CLASS))
  if percentage_hash < validation_percentage:
    result = 'validation'
  elif percentage_hash < (testing_percentage + validation_percentage):
    result = 'testing'
  else:
    result = 'training'
  return result
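
# Illustrative property (hypothetical file names): all '_nohash_' variants of
# the same recording hash identically, so they land in the same partition:
#   assert (which_set('data/yes/bobby_nohash_0.wav', 10, 10) ==
#           which_set('data/yes/bobby_nohash_1.wav', 10, 10))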


def load_wav_file(filename):
  """Loads an audio file and returns a float PCM-encoded array of samples.

  Args:
    filename: Path to the .wav file to load.

  Returns:
    Numpy array holding the sample data as floats between -1.0 and 1.0.
  """
  with tf.Session(graph=tf.Graph()) as sess:
    wav_filename_placeholder = tf.placeholder(tf.string, [])
    wav_loader = io_ops.read_file(wav_filename_placeholder)
    wav_decoder = contrib_audio.decode_wav(wav_loader, desired_channels=1)
    return sess.run(
        wav_decoder,
        feed_dict={wav_filename_placeholder: filename}).audio.flatten()
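
# Illustrative usage (assumes a hypothetical 'example.wav' on disk):
#   samples = load_wav_file('example.wav')
#   # samples is a 1-D numpy float array with values in [-1.0, 1.0].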


def save_wav_file(filename, wav_data, sample_rate):
  """Saves audio sample data to a .wav audio file.

  Args:
    filename: Path to save the file to.
    wav_data: 2D array of float PCM-encoded audio data.
    sample_rate: Samples per second to encode in the file.
  """
  with tf.Session(graph=tf.Graph()) as sess:
    wav_filename_placeholder = tf.placeholder(tf.string, [])
    sample_rate_placeholder = tf.placeholder(tf.int32, [])
    wav_data_placeholder = tf.placeholder(tf.float32, [None, 1])
    wav_encoder = contrib_audio.encode_wav(wav_data_placeholder,
                                           sample_rate_placeholder)
    wav_saver = io_ops.write_file(wav_filename_placeholder, wav_encoder)
    sess.run(
        wav_saver,
        feed_dict={
            wav_filename_placeholder: filename,
            sample_rate_placeholder: sample_rate,
            wav_data_placeholder: np.reshape(wav_data, (-1, 1))
        })
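
# Illustrative round trip (hypothetical paths; 16000 is the sample rate used
# by the Speech Commands data set):
#   samples = load_wav_file('input.wav')
#   save_wav_file('copy.wav', samples, 16000)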


class AudioProcessor(object):
  """Handles loading, partitioning, and preparing audio training data."""

  def __init__(self, data_url, data_dir, silence_percentage, unknown_percentage,
               wanted_words, validation_percentage, testing_percentage,
               model_settings):
    self.data_dir = data_dir
    self.maybe_download_and_extract_dataset(data_url, data_dir)
    self.prepare_data_index(silence_percentage, unknown_percentage,
                            wanted_words, validation_percentage,
                            testing_percentage)
    self.prepare_background_data()
    self.prepare_processing_graph(model_settings)
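
  # Illustrative construction (hypothetical paths and percentages;
  # model_settings would normally come from models.prepare_model_settings):
  #   audio_processor = AudioProcessor(
  #       'http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz',
  #       '/tmp/speech_dataset', 10, 10, ['yes', 'no'], 10, 10, model_settings)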

  def maybe_download_and_extract_dataset(self, data_url, dest_directory):
    """Downloads and extracts the data set tar file.

    If the data set we're using doesn't already exist, this function
    downloads it from the TensorFlow.org website and unpacks it into a
    directory.
    If data_url is None, don't download anything and expect the data
    directory to contain the correct files already.

    Args:
      data_url: Web location of the tar file containing the data set.
      dest_directory: File path to extract data to.
    """
    if not data_url:
      return
    if not os.path.exists(dest_directory):
      os.makedirs(dest_directory)
    filename = data_url.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)
    if not os.path.exists(filepath):

      def _progress(count, block_size, total_size):
        sys.stdout.write(
            '\r>> Downloading %s %.1f%%' %
            (filename, float(count * block_size) / float(total_size) * 100.0))
        sys.stdout.flush()

      try:
        filepath, _ = urllib.request.urlretrieve(data_url, filepath, _progress)
      except Exception:
        tf.logging.error('Failed to download URL: %s to folder: %s', data_url,
                         filepath)
        tf.logging.error('Please make sure you have enough free space and'
                         ' an internet connection')
        raise
      print()
      statinfo = os.stat(filepath)
      tf.logging.info('Successfully downloaded %s (%d bytes)', filename,
                      statinfo.st_size)
    # Use a context manager so the archive's file handle is always closed.
    with tarfile.open(filepath, 'r:gz') as tar:
      tar.extractall(dest_directory)
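
  # After extraction, dest_directory is expected to contain (illustrative) one
  # subfolder per word label, e.g. /tmp/speech_dataset/yes/*.wav, plus the
  # special '_background_noise_' folder used for augmentation.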

  def prepare_data_index(self, silence_percentage, unknown_percentage,
                         wanted_words, validation_percentage,
                         testing_percentage):
    """Prepares a list of the samples organized by set and label.

    The training loop needs a list of all the available data, organized by
    which partition it should belong to, and with ground truth labels attached.
    This function analyzes the folders below the `data_dir`, figures out the
    right labels for each file based on the name of the subdirectory it
    belongs to, and uses a stable hash to assign it to a data set partition.

    Args:
      silence_percentage: How much of the resulting data should be background.
      unknown_percentage: How much should be audio outside the wanted classes.
      wanted_words: Labels of the classes we want to be able to recognize.
      validation_percentage: How much of the data set to use for validation.
      testing_percentage: How much of the data set to use for testing.

    Returns:
      Dictionary containing a list of file information for each set partition,
      and a lookup map for each class to determine its numeric index.

    Raises:
      Exception: If expected files are not found.
    """
    # Make sure the shuffling and picking of unknowns is deterministic.
    random.seed(RANDOM_SEED)
    wanted_words_index = {}
    for index, wanted_word in enumerate(wanted_words):
      wanted_words_index[wanted_word] = index + 2
    self.data_index = {'validation': [], 'testing': [], 'training': []}
    unknown_index = {'validation': [], 'testing': [], 'training': []}
    all_words = {}
    # Look through all the subfolders to find audio samples.
    search_path = os.path.join(self.data_dir, '*', '*.wav')
    for wav_path in gfile.Glob(search_path):
      _, word = os.path.split(os.path.dirname(wav_path))
      word = word.lower()
      # Treat the '_background_noise_' folder as a special case, since we
      # expect it to contain long audio samples we mix in to improve training.
      if word == BACKGROUND_NOISE_DIR_NAME:
        continue
      all_words[word] = True
      set_index = which_set(wav_path, validation_percentage, testing_percentage)
      # If it's a known class, store its details; otherwise add it to the list
      # we'll use to train the unknown label.
      if word in wanted_words_index:
        self.data_index[set_index].append({'label': word, 'file': wav_path})
      else:
        unknown_index[set_index].append({'label': word, 'file': wav_path})
    if not all_words:
      raise Exception('No .wavs found at ' + search_path)
    for index, wanted_word in enumerate(wanted_words):
      if wanted_word not in all_words:
        raise Exception('Expected to find ' + wanted_word +
                        ' in labels but only found ' +
                        ', '.join(all_words.keys()))
    # We need an arbitrary file to load as the input for the silence samples.
    # It's multiplied by zero later, so the content doesn't matter.
    silence_wav_path = self.data_index['training'][0]['file']
    for set_index in ['validation', 'testing', 'training']:
      set_size = len(self.data_index[set_index])
      silence_size = int(math.ceil(set_size * silence_percentage / 100))
      for _ in range(silence_size):
        self.data_index[set_index].append({
            'label': SILENCE_LABEL,
            'file': silence_wav_path
        })
      # Pick some unknowns to add to each partition of the data set.
      random.shuffle(unknown_index[set_index])
      unknown_size = int(math.ceil(set_size * unknown_percentage / 100))
      self.data_index[set_index].extend(unknown_index[set_index][:unknown_size])
    # Make sure the ordering is random.
    for set_index in ['validation', 'testing', 'training']:
      random.shuffle(self.data_index[set_index])
    # Prepare the rest of the result data structure.
    self.words_list = prepare_words_list(wanted_words)
    self.word_to_index = {}
    for word in all_words:
      if word in wanted_words_index:
        self.word_to_index[word] = wanted_words_index[word]
      else:
        self.word_to_index[word] = UNKNOWN_WORD_INDEX
    self.word_to_index[SILENCE_LABEL] = SILENCE_INDEX
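
  # Illustrative shape of the index built above (hypothetical paths, with
  # wanted_words=['yes', 'no'] so any other word maps to the unknown index):
  #   self.data_index['training'] == [
  #       {'label': 'yes', 'file': '/tmp/speech_dataset/yes/a_nohash_0.wav'},
  #       {'label': '_silence_', 'file': ...}, ...]
  #   self.word_to_index == {'_silence_': 0, 'happy': 1, 'yes': 2, 'no': 3}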

  def prepare_background_data(self):
    """Searches a folder for background noise audio, and loads it into memory.

    It's expected that the background audio samples will be in a subdirectory
    named '_background_noise_' inside the 'data_dir' folder, as .wavs that match
    the sample rate of the training data, but can be much longer in duration.

    If the '_background_noise_' folder doesn't exist at all, this isn't an
    error, it's just taken to mean that no background noise augmentation should
    be used. If the folder does exist, but it's empty, that's treated as an
    error.

    Returns:
      List of raw PCM-encoded audio samples of background noise.

    Raises:
      Exception: If files aren't found in the folder.
    """
    self.background_data = []
    background_dir = os.path.join(self.data_dir, BACKGROUND_NOISE_DIR_NAME)
    if not os.path.exists(background_dir):
      return self.background_data
    with tf.Session(graph=tf.Graph()) as sess:
      wav_filename_placeholder = tf.placeholder(tf.string, [])
      wav_loader = io_ops.read_file(wav_filename_placeholder)
      wav_decoder = contrib_audio.decode_wav(wav_loader, desired_channels=1)
      search_path = os.path.join(self.data_dir, BACKGROUND_NOISE_DIR_NAME,
                                 '*.wav')
      for wav_path in gfile.Glob(search_path):
        wav_data = sess.run(
            wav_decoder,
            feed_dict={wav_filename_placeholder: wav_path}).audio.flatten()
        self.background_data.append(wav_data)
      if not self.background_data:
        raise Exception('No background wav files were found in ' + search_path)
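
  # After this runs, self.background_data is a list of 1-D numpy float arrays,
  # one per background wav, that get_data() slices random sections from.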

  def prepare_processing_graph(self, model_settings):
    """Builds a TensorFlow graph to apply the input distortions.

    Creates a graph that loads a WAVE file, decodes it, scales the volume,
    shifts it in time, adds in background noise, calculates a spectrogram, and
    then builds an MFCC fingerprint from that.

    This must be called with an active TensorFlow session running, and it
    creates multiple placeholder inputs, and one output:

      - wav_filename_placeholder_: Filename of the WAV to load.
      - foreground_volume_placeholder_: How loud the main clip should be.
      - time_shift_padding_placeholder_: Where to pad the clip.
      - time_shift_offset_placeholder_: How much to move the clip in time.
      - background_data_placeholder_: PCM sample data for background noise.
      - background_volume_placeholder_: Loudness of mixed-in background.
      - mfcc_: Output 2D fingerprint of processed audio.

    Args:
      model_settings: Information about the current model being trained.
    """
    desired_samples = model_settings['desired_samples']
    self.wav_filename_placeholder_ = tf.placeholder(tf.string, [])
    wav_loader = io_ops.read_file(self.wav_filename_placeholder_)
    wav_decoder = contrib_audio.decode_wav(
        wav_loader, desired_channels=1, desired_samples=desired_samples)
    # Allow the audio sample's volume to be adjusted.
    self.foreground_volume_placeholder_ = tf.placeholder(tf.float32, [])
    scaled_foreground = tf.multiply(wav_decoder.audio,
                                    self.foreground_volume_placeholder_)
    # Shift the sample's start position, and pad any gaps with zeros.
    self.time_shift_padding_placeholder_ = tf.placeholder(tf.int32, [2, 2])
    self.time_shift_offset_placeholder_ = tf.placeholder(tf.int32, [2])
    padded_foreground = tf.pad(
        scaled_foreground,
        self.time_shift_padding_placeholder_,
        mode='CONSTANT')
    sliced_foreground = tf.slice(padded_foreground,
                                 self.time_shift_offset_placeholder_,
                                 [desired_samples, -1])
    # Mix in background noise.
    self.background_data_placeholder_ = tf.placeholder(tf.float32,
                                                       [desired_samples, 1])
    self.background_volume_placeholder_ = tf.placeholder(tf.float32, [])
    background_mul = tf.multiply(self.background_data_placeholder_,
                                 self.background_volume_placeholder_)
    background_add = tf.add(background_mul, sliced_foreground)
    background_clamp = tf.clip_by_value(background_add, -1.0, 1.0)
    # Run the spectrogram and MFCC ops to get a 2D 'fingerprint' of the audio.
    spectrogram = contrib_audio.audio_spectrogram(
        background_clamp,
        window_size=model_settings['window_size_samples'],
        stride=model_settings['window_stride_samples'],
        magnitude_squared=True)
    self.mfcc_ = contrib_audio.mfcc(
        spectrogram,
        wav_decoder.sample_rate,
        dct_coefficient_count=model_settings['dct_coefficient_count'])
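
  # Illustrative evaluation of the graph built above (hypothetical values;
  # get_data() below fills this feed dict for real):
  #   fingerprint = sess.run(self.mfcc_, feed_dict={
  #       self.wav_filename_placeholder_: 'some.wav',
  #       self.foreground_volume_placeholder_: 1.0,
  #       self.time_shift_padding_placeholder_: [[0, 0], [0, 0]],
  #       self.time_shift_offset_placeholder_: [0, 0],
  #       self.background_data_placeholder_: np.zeros((desired_samples, 1)),
  #       self.background_volume_placeholder_: 0.0,
  #   })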

  def set_size(self, mode):
    """Calculates the number of samples in the data set partition.

    Args:
      mode: Which partition, must be 'training', 'validation', or 'testing'.

    Returns:
      Number of samples in the partition.
    """
    return len(self.data_index[mode])
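
  # e.g. (illustrative) audio_processor.set_size('validation') counts the
  # validation samples, including the generated silence and unknown entries.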

  def get_data(self, how_many, offset, model_settings, background_frequency,
               background_volume_range, time_shift, mode, sess):
    """Gathers samples from the data set, applying transformations as needed.

    When the mode is 'training', a random selection of samples will be
    returned; otherwise the first N clips in the partition will be used. This
    ensures that validation always uses the same samples, reducing noise in
    the metrics.

    Args:
      how_many: Desired number of samples to return. -1 means the entire
        contents of this partition.
      offset: Where to start when fetching deterministically.
      model_settings: Information about the current model being trained.
      background_frequency: How many clips will have background noise, 0.0 to
        1.0.
      background_volume_range: How loud the background noise will be.
      time_shift: How much to randomly shift the clips by in time.
      mode: Which partition to use, must be 'training', 'validation', or
        'testing'.
      sess: TensorFlow session that was active when the processor was created.

    Returns:
      List of sample data for the transformed samples, and list of label
      indexes.
    """
    # Pick one of the partitions to choose samples from.
    candidates = self.data_index[mode]
    if how_many == -1:
      sample_count = len(candidates)
    else:
      sample_count = max(0, min(how_many, len(candidates) - offset))
    # Data and labels will be populated and returned.
    data = np.zeros((sample_count, model_settings['fingerprint_size']))
    labels = np.zeros(sample_count)
    desired_samples = model_settings['desired_samples']
    use_background = self.background_data and (mode == 'training')
    pick_deterministically = (mode != 'training')
    # Use the processing graph we created earlier to repeatedly generate the
    # final output sample data we'll use in training.
    for i in xrange(offset, offset + sample_count):
      # Pick which audio sample to use.
      if how_many == -1 or pick_deterministically:
        sample_index = i
      else:
        sample_index = np.random.randint(len(candidates))
      sample = candidates[sample_index]
      # If we're time shifting, set up the offset for this sample.
      if time_shift > 0:
        time_shift_amount = np.random.randint(-time_shift, time_shift)
      else:
        time_shift_amount = 0
      if time_shift_amount > 0:
        time_shift_padding = [[time_shift_amount, 0], [0, 0]]
        time_shift_offset = [0, 0]
      else:
        time_shift_padding = [[0, -time_shift_amount], [0, 0]]
        time_shift_offset = [-time_shift_amount, 0]
      input_dict = {
          self.wav_filename_placeholder_: sample['file'],
          self.time_shift_padding_placeholder_: time_shift_padding,
          self.time_shift_offset_placeholder_: time_shift_offset,
      }
      # Choose a section of background noise to mix in.
      if use_background:
        background_index = np.random.randint(len(self.background_data))
        background_samples = self.background_data[background_index]
        background_offset = np.random.randint(
            0, len(background_samples) - model_settings['desired_samples'])
        background_clipped = background_samples[background_offset:(
            background_offset + desired_samples)]
        background_reshaped = background_clipped.reshape([desired_samples, 1])
        if np.random.uniform(0, 1) < background_frequency:
          background_volume = np.random.uniform(0, background_volume_range)
        else:
          background_volume = 0
      else:
        background_reshaped = np.zeros([desired_samples, 1])
        background_volume = 0
      input_dict[self.background_data_placeholder_] = background_reshaped
      input_dict[self.background_volume_placeholder_] = background_volume
      # If we want silence, mute out the main sample but leave the background.
      if sample['label'] == SILENCE_LABEL:
        input_dict[self.foreground_volume_placeholder_] = 0
      else:
        input_dict[self.foreground_volume_placeholder_] = 1
      # Run the graph to produce the output audio.
      data[i - offset, :] = sess.run(self.mfcc_, feed_dict=input_dict).flatten()
      label_index = self.word_to_index[sample['label']]
      labels[i - offset] = label_index
    return data, labels
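
  # Illustrative call from a training loop (hypothetical hyperparameters; sess
  # must be the session that was active when this processor was constructed):
  #   train_fingerprints, train_ground_truth = audio_processor.get_data(
  #       100, 0, model_settings, 0.8, 0.1, 100, 'training', sess)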

  def get_unprocessed_data(self, how_many, model_settings, mode):
    """Retrieves sample data for the given partition, with no transformations.

    Args:
      how_many: Desired number of samples to return. -1 means the entire
        contents of this partition.
      model_settings: Information about the current model being trained.
      mode: Which partition to use, must be 'training', 'validation', or
        'testing'.

    Returns:
      List of sample data for the samples, and list of label strings.
    """
    candidates = self.data_index[mode]
    if how_many == -1:
      sample_count = len(candidates)
    else:
      sample_count = how_many
    desired_samples = model_settings['desired_samples']
    words_list = self.words_list
    data = np.zeros((sample_count, desired_samples))
    labels = []
    with tf.Session(graph=tf.Graph()) as sess:
      wav_filename_placeholder = tf.placeholder(tf.string, [])
      wav_loader = io_ops.read_file(wav_filename_placeholder)
      wav_decoder = contrib_audio.decode_wav(
          wav_loader, desired_channels=1, desired_samples=desired_samples)
      foreground_volume_placeholder = tf.placeholder(tf.float32, [])
      scaled_foreground = tf.multiply(wav_decoder.audio,
                                      foreground_volume_placeholder)
      for i in range(sample_count):
        if how_many == -1:
          sample_index = i
        else:
          sample_index = np.random.randint(len(candidates))
        sample = candidates[sample_index]
        input_dict = {wav_filename_placeholder: sample['file']}
        if sample['label'] == SILENCE_LABEL:
          input_dict[foreground_volume_placeholder] = 0
        else:
          input_dict[foreground_volume_placeholder] = 1
        data[i, :] = sess.run(scaled_foreground, feed_dict=input_dict).flatten()
        label_index = self.word_to_index[sample['label']]
        labels.append(words_list[label_index])
    return data, labels
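
  # Illustrative call (hypothetical model_settings): fetch ten raw validation
  # clips and their label strings, with no augmentation applied:
  #   raw_audio, raw_labels = audio_processor.get_unprocessed_data(
  #       10, model_settings, 'validation')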
    533