# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Data loading and preparation utilities for simple speech recognition."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import hashlib
import math
import os.path
import random
import re
import sys
import tarfile

import numpy as np
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
from tensorflow.python.ops import io_ops
from tensorflow.python.platform import gfile
from tensorflow.python.util import compat

MAX_NUM_WAVS_PER_CLASS = 2**27 - 1  # ~134M
SILENCE_LABEL = '_silence_'
SILENCE_INDEX = 0
UNKNOWN_WORD_LABEL = '_unknown_'
UNKNOWN_WORD_INDEX = 1
BACKGROUND_NOISE_DIR_NAME = '_background_noise_'
RANDOM_SEED = 59185


def prepare_words_list(wanted_words):
  """Prepends common tokens to the custom word list.

  Args:
    wanted_words: List of strings containing the custom words.

  Returns:
    List with the standard silence and unknown tokens added.
  """
  return [SILENCE_LABEL, UNKNOWN_WORD_LABEL] + wanted_words
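
# For example, prepare_words_list(['yes', 'no']) returns
# ['_silence_', '_unknown_', 'yes', 'no'], so the list order matches
# SILENCE_INDEX (0) and UNKNOWN_WORD_INDEX (1) defined above.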

def which_set(filename, validation_percentage, testing_percentage):
  """Determines which data partition the file should belong to.

  We want to keep files in the same training, validation, or testing sets even
  if new ones are added over time. This makes it less likely that testing
  samples will accidentally be reused in training when long runs are
  restarted, for example. To keep this stable, a hash of the filename is taken
  and used to determine which set it should belong to. This determination only
  depends on the name and the set proportions, so it won't change as other
  files are added.

  It's also useful to associate particular files as related (for example,
  words spoken by the same person), so anything after '_nohash_' in a filename
  is ignored for set determination. This ensures that 'bobby_nohash_0.wav' and
  'bobby_nohash_1.wav' are always in the same set, for example.

  Args:
    filename: File path of the data sample.
    validation_percentage: How much of the data set to use for validation.
    testing_percentage: How much of the data set to use for testing.

  Returns:
    String, one of 'training', 'validation', or 'testing'.
  """
  base_name = os.path.basename(filename)
  # We want to ignore anything after '_nohash_' in the file name when
  # deciding which set to put a wav in, so the data set creator has a way of
  # grouping wavs that are close variations of each other.
  hash_name = re.sub(r'_nohash_.*$', '', base_name)
  # This looks a bit magical, but we need to decide whether this file should
  # go into the training, testing, or validation sets, and we want to keep
  # existing files in the same set even if more files are subsequently
  # added.
  # To do that, we need a stable way of deciding based on just the file name
  # itself, so we do a hash of that and then use that to generate a
  # probability value that we use to assign it.
  hash_name_hashed = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest()
  percentage_hash = ((int(hash_name_hashed, 16) %
                      (MAX_NUM_WAVS_PER_CLASS + 1)) *
                     (100.0 / MAX_NUM_WAVS_PER_CLASS))
  if percentage_hash < validation_percentage:
    result = 'validation'
  elif percentage_hash < (testing_percentage + validation_percentage):
    result = 'testing'
  else:
    result = 'training'
  return result
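
# A minimal sketch of the stability property (hypothetical file names):
# everything after '_nohash_' is stripped before hashing, so recordings from
# the same speaker always land in the same partition:
#
#   which_set('data/yes/bobby_nohash_0.wav', 10, 10)
#   which_set('data/yes/bobby_nohash_1.wav', 10, 10)
#
# Both calls hash the same string ('bobby') and therefore always return the
# same partition name, no matter how many other files are added later.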

def load_wav_file(filename):
  """Loads an audio file and returns a float PCM-encoded array of samples.

  Args:
    filename: Path to the .wav file to load.

  Returns:
    Numpy array holding the sample data as floats between -1.0 and 1.0.
  """
  with tf.Session(graph=tf.Graph()) as sess:
    wav_filename_placeholder = tf.placeholder(tf.string, [])
    wav_loader = io_ops.read_file(wav_filename_placeholder)
    wav_decoder = contrib_audio.decode_wav(wav_loader, desired_channels=1)
    return sess.run(
        wav_decoder,
        feed_dict={wav_filename_placeholder: filename}).audio.flatten()


def save_wav_file(filename, wav_data, sample_rate):
  """Saves audio sample data to a .wav audio file.

  Args:
    filename: Path to save the file to.
    wav_data: 2D array of float PCM-encoded audio data.
    sample_rate: Samples per second to encode in the file.
  """
  with tf.Session(graph=tf.Graph()) as sess:
    wav_filename_placeholder = tf.placeholder(tf.string, [])
    sample_rate_placeholder = tf.placeholder(tf.int32, [])
    wav_data_placeholder = tf.placeholder(tf.float32, [None, 1])
    wav_encoder = contrib_audio.encode_wav(wav_data_placeholder,
                                           sample_rate_placeholder)
    wav_saver = io_ops.write_file(wav_filename_placeholder, wav_encoder)
    sess.run(
        wav_saver,
        feed_dict={
            wav_filename_placeholder: filename,
            sample_rate_placeholder: sample_rate,
            wav_data_placeholder: np.reshape(wav_data, (-1, 1))
        })
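
# A minimal round-trip sketch (hypothetical path; one second of 16kHz
# silence):
#
#   samples = np.zeros((16000, 1), dtype=np.float32)
#   save_wav_file('/tmp/example.wav', samples, 16000)
#   loaded = load_wav_file('/tmp/example.wav')  # 1D array of 16000 floats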

class AudioProcessor(object):
  """Handles loading, partitioning, and preparing audio training data."""

  def __init__(self, data_url, data_dir, silence_percentage, unknown_percentage,
               wanted_words, validation_percentage, testing_percentage,
               model_settings):
    self.data_dir = data_dir
    self.maybe_download_and_extract_dataset(data_url, data_dir)
    self.prepare_data_index(silence_percentage, unknown_percentage,
                            wanted_words, validation_percentage,
                            testing_percentage)
    self.prepare_background_data()
    self.prepare_processing_graph(model_settings)

  def maybe_download_and_extract_dataset(self, data_url, dest_directory):
    """Downloads and extracts the data set tar file.

    If the data set we're using doesn't already exist, this function
    downloads it from the TensorFlow.org website and unpacks it into a
    directory.
    If data_url is empty, nothing is downloaded, and the data directory is
    expected to contain the correct files already.

    Args:
      data_url: Web location of the tar file containing the data set.
      dest_directory: File path to extract data to.
    """
    if not data_url:
      return
    if not os.path.exists(dest_directory):
      os.makedirs(dest_directory)
    filename = data_url.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)
    if not os.path.exists(filepath):

      def _progress(count, block_size, total_size):
        sys.stdout.write(
            '\r>> Downloading %s %.1f%%' %
            (filename, float(count * block_size) / float(total_size) * 100.0))
        sys.stdout.flush()

      try:
        filepath, _ = urllib.request.urlretrieve(data_url, filepath, _progress)
      except:
        tf.logging.error('Failed to download URL: %s to folder: %s', data_url,
                         filepath)
        tf.logging.error('Please make sure you have enough free space and'
                         ' an internet connection')
        raise
      print()
      statinfo = os.stat(filepath)
      tf.logging.info('Successfully downloaded %s (%d bytes)', filename,
                      statinfo.st_size)
    tarfile.open(filepath, 'r:gz').extractall(dest_directory)

  def prepare_data_index(self, silence_percentage, unknown_percentage,
                         wanted_words, validation_percentage,
                         testing_percentage):
    """Prepares a list of the samples organized by set and label.

    The training loop needs a list of all the available data, organized by
    which partition it should belong to, and with ground truth labels
    attached. This function analyzes the folders below the `data_dir`, figures
    out the right labels for each file based on the name of the subdirectory
    it belongs to, and uses a stable hash to assign it to a data set
    partition.

    Args:
      silence_percentage: How much of the resulting data should be background.
      unknown_percentage: How much should be audio outside the wanted classes.
      wanted_words: Labels of the classes we want to be able to recognize.
      validation_percentage: How much of the data set to use for validation.
      testing_percentage: How much of the data set to use for testing.

    Returns:
      Dictionary containing a list of file information for each set partition,
      and a lookup map for each class to determine its numeric index.

    Raises:
      Exception: If expected files are not found.
    """
    # Make sure the shuffling and picking of unknowns is deterministic.
    random.seed(RANDOM_SEED)
    wanted_words_index = {}
    for index, wanted_word in enumerate(wanted_words):
      wanted_words_index[wanted_word] = index + 2
    self.data_index = {'validation': [], 'testing': [], 'training': []}
    unknown_index = {'validation': [], 'testing': [], 'training': []}
    all_words = {}
    # Look through all the subfolders to find audio samples.
    search_path = os.path.join(self.data_dir, '*', '*.wav')
    for wav_path in gfile.Glob(search_path):
      _, word = os.path.split(os.path.dirname(wav_path))
      word = word.lower()
      # Treat the '_background_noise_' folder as a special case, since we
      # expect it to contain long audio samples we mix in to improve training.
      if word == BACKGROUND_NOISE_DIR_NAME:
        continue
      all_words[word] = True
      set_index = which_set(wav_path, validation_percentage,
                            testing_percentage)
      # If it's a known class, store its detail, otherwise add it to the list
      # we'll use to train the unknown label.
      if word in wanted_words_index:
        self.data_index[set_index].append({'label': word, 'file': wav_path})
      else:
        unknown_index[set_index].append({'label': word, 'file': wav_path})
    if not all_words:
      raise Exception('No .wavs found at ' + search_path)
    for index, wanted_word in enumerate(wanted_words):
      if wanted_word not in all_words:
        raise Exception('Expected to find ' + wanted_word +
                        ' in labels but only found ' +
                        ', '.join(all_words.keys()))
    # We need an arbitrary file to load as the input for the silence samples.
    # It's multiplied by zero later, so the content doesn't matter.
    silence_wav_path = self.data_index['training'][0]['file']
    for set_index in ['validation', 'testing', 'training']:
      set_size = len(self.data_index[set_index])
      silence_size = int(math.ceil(set_size * silence_percentage / 100))
      for _ in range(silence_size):
        self.data_index[set_index].append({
            'label': SILENCE_LABEL,
            'file': silence_wav_path
        })
      # Pick some unknowns to add to each partition of the data set.
      random.shuffle(unknown_index[set_index])
      unknown_size = int(math.ceil(set_size * unknown_percentage / 100))
      self.data_index[set_index].extend(
          unknown_index[set_index][:unknown_size])
    # Make sure the ordering is random.
    for set_index in ['validation', 'testing', 'training']:
      random.shuffle(self.data_index[set_index])
    # Prepare the rest of the result data structure.
    self.words_list = prepare_words_list(wanted_words)
    self.word_to_index = {}
    for word in all_words:
      if word in wanted_words_index:
        self.word_to_index[word] = wanted_words_index[word]
      else:
        self.word_to_index[word] = UNKNOWN_WORD_INDEX
    self.word_to_index[SILENCE_LABEL] = SILENCE_INDEX
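
  # For wanted_words=['yes', 'no'], for example, the resulting lookup tables
  # are:
  #   words_list:    ['_silence_', '_unknown_', 'yes', 'no']
  #   word_to_index: 'yes' -> 2, 'no' -> 3, '_silence_' -> 0, and every other
  #                  word found on disk -> 1 (the unknown index).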

  def prepare_background_data(self):
    """Searches a folder for background noise audio, and loads it into memory.

    It's expected that the background audio samples will be in a subdirectory
    named '_background_noise_' inside the 'data_dir' folder, as .wavs that
    match the sample rate of the training data, but can be much longer in
    duration.

    If the '_background_noise_' folder doesn't exist at all, this isn't an
    error, it's just taken to mean that no background noise augmentation
    should be used. If the folder does exist, but it's empty, that's treated
    as an error.

    Returns:
      List of raw PCM-encoded audio samples of background noise.

    Raises:
      Exception: If files aren't found in the folder.
    """
    self.background_data = []
    background_dir = os.path.join(self.data_dir, BACKGROUND_NOISE_DIR_NAME)
    if not os.path.exists(background_dir):
      return self.background_data
    with tf.Session(graph=tf.Graph()) as sess:
      wav_filename_placeholder = tf.placeholder(tf.string, [])
      wav_loader = io_ops.read_file(wav_filename_placeholder)
      wav_decoder = contrib_audio.decode_wav(wav_loader, desired_channels=1)
      search_path = os.path.join(self.data_dir, BACKGROUND_NOISE_DIR_NAME,
                                 '*.wav')
      for wav_path in gfile.Glob(search_path):
        wav_data = sess.run(
            wav_decoder,
            feed_dict={wav_filename_placeholder: wav_path}).audio.flatten()
        self.background_data.append(wav_data)
      if not self.background_data:
        raise Exception('No background wav files were found in ' + search_path)

  def prepare_processing_graph(self, model_settings):
    """Builds a TensorFlow graph to apply the input distortions.

    Creates a graph that loads a WAVE file, decodes it, scales the volume,
    shifts it in time, adds in background noise, calculates a spectrogram, and
    then builds an MFCC fingerprint from that.

    This must be called with an active TensorFlow session running, and it
    creates multiple placeholder inputs, and one output:

      - wav_filename_placeholder_: Filename of the WAV to load.
      - foreground_volume_placeholder_: How loud the main clip should be.
      - time_shift_padding_placeholder_: Where to pad the clip.
      - time_shift_offset_placeholder_: How much to move the clip in time.
      - background_data_placeholder_: PCM sample data for background noise.
      - background_volume_placeholder_: Loudness of mixed-in background.
      - mfcc_: Output 2D fingerprint of processed audio.

    Args:
      model_settings: Information about the current model being trained.
    """
    desired_samples = model_settings['desired_samples']
    self.wav_filename_placeholder_ = tf.placeholder(tf.string, [])
    wav_loader = io_ops.read_file(self.wav_filename_placeholder_)
    wav_decoder = contrib_audio.decode_wav(
        wav_loader, desired_channels=1, desired_samples=desired_samples)
    # Allow the audio sample's volume to be adjusted.
    self.foreground_volume_placeholder_ = tf.placeholder(tf.float32, [])
    scaled_foreground = tf.multiply(wav_decoder.audio,
                                    self.foreground_volume_placeholder_)
    # Shift the sample's start position, and pad any gaps with zeros.
    self.time_shift_padding_placeholder_ = tf.placeholder(tf.int32, [2, 2])
    self.time_shift_offset_placeholder_ = tf.placeholder(tf.int32, [2])
    padded_foreground = tf.pad(
        scaled_foreground,
        self.time_shift_padding_placeholder_,
        mode='CONSTANT')
    sliced_foreground = tf.slice(padded_foreground,
                                 self.time_shift_offset_placeholder_,
                                 [desired_samples, -1])
    # Mix in background noise.
    self.background_data_placeholder_ = tf.placeholder(tf.float32,
                                                       [desired_samples, 1])
    self.background_volume_placeholder_ = tf.placeholder(tf.float32, [])
    background_mul = tf.multiply(self.background_data_placeholder_,
                                 self.background_volume_placeholder_)
    background_add = tf.add(background_mul, sliced_foreground)
    background_clamp = tf.clip_by_value(background_add, -1.0, 1.0)
    # Run the spectrogram and MFCC ops to get a 2D 'fingerprint' of the audio.
    spectrogram = contrib_audio.audio_spectrogram(
        background_clamp,
        window_size=model_settings['window_size_samples'],
        stride=model_settings['window_stride_samples'],
        magnitude_squared=True)
    self.mfcc_ = contrib_audio.mfcc(
        spectrogram,
        wav_decoder.sample_rate,
        dct_coefficient_count=model_settings['dct_coefficient_count'])
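
  # A worked sketch of the time-shift inputs (values are illustrative): with
  # desired_samples=16000 and a shift of +100 samples, get_data() below feeds
  # padding [[100, 0], [0, 0]] and offset [0, 0], delaying the clip and
  # trimming its tail; for a shift of -100 it feeds padding [[0, 100], [0, 0]]
  # and offset [100, 0], advancing the clip. The numpy equivalent would be:
  #
  #   padded = np.pad(audio, [[100, 0], [0, 0]], mode='constant')
  #   shifted = padded[0:16000, :]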

  def set_size(self, mode):
    """Calculates the number of samples in the data set partition.

    Args:
      mode: Which partition, must be 'training', 'validation', or 'testing'.

    Returns:
      Number of samples in the partition.
    """
    return len(self.data_index[mode])

  def get_data(self, how_many, offset, model_settings, background_frequency,
               background_volume_range, time_shift, mode, sess):
    """Gathers samples from the data set, applying transformations as needed.

    When the mode is 'training', a random selection of samples will be
    returned; otherwise the first N clips in the partition will be used. This
    ensures that validation always uses the same samples, reducing noise in
    the metrics.

    Args:
      how_many: Desired number of samples to return. -1 means the entire
        contents of this partition.
      offset: Where to start when fetching deterministically.
      model_settings: Information about the current model being trained.
      background_frequency: What proportion of clips will have background
        noise mixed in, 0.0 to 1.0.
      background_volume_range: How loud the background noise will be.
      time_shift: How much to randomly shift the clips by in time.
      mode: Which partition to use, must be 'training', 'validation', or
        'testing'.
      sess: TensorFlow session that was active when the processor was created.

    Returns:
      List of sample data for the transformed samples, and list of label
      indexes.
    """
    # Pick one of the partitions to choose samples from.
    candidates = self.data_index[mode]
    if how_many == -1:
      sample_count = len(candidates)
    else:
      sample_count = max(0, min(how_many, len(candidates) - offset))
    # Data and labels will be populated and returned.
    data = np.zeros((sample_count, model_settings['fingerprint_size']))
    labels = np.zeros(sample_count)
    desired_samples = model_settings['desired_samples']
    use_background = self.background_data and (mode == 'training')
    pick_deterministically = (mode != 'training')
    # Use the processing graph we created earlier repeatedly to generate the
    # final output sample data we'll use in training.
    for i in xrange(offset, offset + sample_count):
      # Pick which audio sample to use.
      if how_many == -1 or pick_deterministically:
        sample_index = i
      else:
        sample_index = np.random.randint(len(candidates))
      sample = candidates[sample_index]
      # If we're time shifting, set up the offset for this sample.
      if time_shift > 0:
        time_shift_amount = np.random.randint(-time_shift, time_shift)
      else:
        time_shift_amount = 0
      if time_shift_amount > 0:
        time_shift_padding = [[time_shift_amount, 0], [0, 0]]
        time_shift_offset = [0, 0]
      else:
        time_shift_padding = [[0, -time_shift_amount], [0, 0]]
        time_shift_offset = [-time_shift_amount, 0]
      input_dict = {
          self.wav_filename_placeholder_: sample['file'],
          self.time_shift_padding_placeholder_: time_shift_padding,
          self.time_shift_offset_placeholder_: time_shift_offset,
      }
      # Choose a section of background noise to mix in.
      if use_background:
        background_index = np.random.randint(len(self.background_data))
        background_samples = self.background_data[background_index]
        background_offset = np.random.randint(
            0, len(background_samples) - model_settings['desired_samples'])
        background_clipped = background_samples[background_offset:(
            background_offset + desired_samples)]
        background_reshaped = background_clipped.reshape([desired_samples, 1])
        if np.random.uniform(0, 1) < background_frequency:
          background_volume = np.random.uniform(0, background_volume_range)
        else:
          background_volume = 0
      else:
        background_reshaped = np.zeros([desired_samples, 1])
        background_volume = 0
      input_dict[self.background_data_placeholder_] = background_reshaped
      input_dict[self.background_volume_placeholder_] = background_volume
      # If we want silence, mute out the main sample but leave the background.
      if sample['label'] == SILENCE_LABEL:
        input_dict[self.foreground_volume_placeholder_] = 0
      else:
        input_dict[self.foreground_volume_placeholder_] = 1
      # Run the graph to produce the output audio.
      data[i - offset, :] = sess.run(self.mfcc_,
                                     feed_dict=input_dict).flatten()
      label_index = self.word_to_index[sample['label']]
      labels[i - offset] = label_index
    return data, labels
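
  # The returned arrays line up index-for-index: data has shape
  # (sample_count, fingerprint_size) with one flattened MFCC fingerprint per
  # row, and labels holds the matching numeric label indexes.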

  def get_unprocessed_data(self, how_many, model_settings, mode):
    """Retrieves sample data for the given partition, with no transformations.

    Args:
      how_many: Desired number of samples to return. -1 means the entire
        contents of this partition.
      model_settings: Information about the current model being trained.
      mode: Which partition to use, must be 'training', 'validation', or
        'testing'.

    Returns:
      List of sample data for the samples, and list of labels as strings.
    """
    candidates = self.data_index[mode]
    if how_many == -1:
      sample_count = len(candidates)
    else:
      sample_count = how_many
    desired_samples = model_settings['desired_samples']
    words_list = self.words_list
    data = np.zeros((sample_count, desired_samples))
    labels = []
    with tf.Session(graph=tf.Graph()) as sess:
      wav_filename_placeholder = tf.placeholder(tf.string, [])
      wav_loader = io_ops.read_file(wav_filename_placeholder)
      wav_decoder = contrib_audio.decode_wav(
          wav_loader, desired_channels=1, desired_samples=desired_samples)
      foreground_volume_placeholder = tf.placeholder(tf.float32, [])
      scaled_foreground = tf.multiply(wav_decoder.audio,
                                      foreground_volume_placeholder)
      for i in range(sample_count):
        if how_many == -1:
          sample_index = i
        else:
          sample_index = np.random.randint(len(candidates))
        sample = candidates[sample_index]
        input_dict = {wav_filename_placeholder: sample['file']}
        if sample['label'] == SILENCE_LABEL:
          input_dict[foreground_volume_placeholder] = 0
        else:
          input_dict[foreground_volume_placeholder] = 1
        data[i, :] = sess.run(scaled_foreground,
                              feed_dict=input_dict).flatten()
        label_index = self.word_to_index[sample['label']]
        labels.append(words_list[label_index])
    return data, labels
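

# A minimal end-to-end usage sketch. The data path, wanted words, and model
# settings below are illustrative assumptions, not part of the original API.
def _example_usage():
  """Shows how the pieces above fit together.

  Assumes the Speech Commands archive has already been extracted under the
  hypothetical path /tmp/speech_dataset, and uses settings for one second of
  16kHz audio with a 30ms window and 10ms stride, giving 98 frames of 40 MFCC
  coefficients each.
  """
  model_settings = {
      'desired_samples': 16000,
      'window_size_samples': 480,
      'window_stride_samples': 160,
      'dct_coefficient_count': 40,
      'fingerprint_size': 3920,  # 98 frames * 40 coefficients.
  }
  with tf.Session() as sess:
    audio_processor = AudioProcessor(
        data_url='',  # An empty URL skips the download step.
        data_dir='/tmp/speech_dataset',
        silence_percentage=10,
        unknown_percentage=10,
        wanted_words=['yes', 'no'],
        validation_percentage=10,
        testing_percentage=10,
        model_settings=model_settings)
    # Fetch 100 augmented training samples, with background noise mixed into
    # 80% of clips at up to 10% volume, and up to 100 samples of time shift.
    fingerprints, ground_truth = audio_processor.get_data(
        100, 0, model_settings, 0.8, 0.1, 100, 'training', sess)
    print(fingerprints.shape, ground_truth.shape)  # (100, 3920) (100,)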