1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Model definitions for simple speech recognition. 16 17 """ 18 from __future__ import absolute_import 19 from __future__ import division 20 from __future__ import print_function 21 22 import math 23 24 import tensorflow as tf 25 26 27 def prepare_model_settings(label_count, sample_rate, clip_duration_ms, 28 window_size_ms, window_stride_ms, 29 dct_coefficient_count): 30 """Calculates common settings needed for all models. 31 32 Args: 33 label_count: How many classes are to be recognized. 34 sample_rate: Number of audio samples per second. 35 clip_duration_ms: Length of each audio clip to be analyzed. 36 window_size_ms: Duration of frequency analysis window. 37 window_stride_ms: How far to move in time between frequency windows. 38 dct_coefficient_count: Number of frequency bins to use for analysis. 39 40 Returns: 41 Dictionary containing common settings. 42 """ 43 desired_samples = int(sample_rate * clip_duration_ms / 1000) 44 window_size_samples = int(sample_rate * window_size_ms / 1000) 45 window_stride_samples = int(sample_rate * window_stride_ms / 1000) 46 length_minus_window = (desired_samples - window_size_samples) 47 if length_minus_window < 0: 48 spectrogram_length = 0 49 else: 50 spectrogram_length = 1 + int(length_minus_window / window_stride_samples) 51 fingerprint_size = dct_coefficient_count * spectrogram_length 52 return { 53 'desired_samples': desired_samples, 54 'window_size_samples': window_size_samples, 55 'window_stride_samples': window_stride_samples, 56 'spectrogram_length': spectrogram_length, 57 'dct_coefficient_count': dct_coefficient_count, 58 'fingerprint_size': fingerprint_size, 59 'label_count': label_count, 60 'sample_rate': sample_rate, 61 } 62 63 64 def create_model(fingerprint_input, model_settings, model_architecture, 65 is_training, runtime_settings=None): 66 """Builds a model of the requested architecture compatible with the settings. 67 68 There are many possible ways of deriving predictions from a spectrogram 69 input, so this function provides an abstract interface for creating different 70 kinds of models in a black-box way. You need to pass in a TensorFlow node as 71 the 'fingerprint' input, and this should output a batch of 1D features that 72 describe the audio. Typically this will be derived from a spectrogram that's 73 been run through an MFCC, but in theory it can be any feature vector of the 74 size specified in model_settings['fingerprint_size']. 75 76 The function will build the graph it needs in the current TensorFlow graph, 77 and return the tensorflow output that will contain the 'logits' input to the 78 softmax prediction process. If training flag is on, it will also return a 79 placeholder node that can be used to control the dropout amount. 80 81 See the implementations below for the possible model architectures that can be 82 requested. 83 84 Args: 85 fingerprint_input: TensorFlow node that will output audio feature vectors. 86 model_settings: Dictionary of information about the model. 87 model_architecture: String specifying which kind of model to create. 88 is_training: Whether the model is going to be used for training. 89 runtime_settings: Dictionary of information about the runtime. 90 91 Returns: 92 TensorFlow node outputting logits results, and optionally a dropout 93 placeholder. 94 95 Raises: 96 Exception: If the architecture type isn't recognized. 97 """ 98 if model_architecture == 'single_fc': 99 return create_single_fc_model(fingerprint_input, model_settings, 100 is_training) 101 elif model_architecture == 'conv': 102 return create_conv_model(fingerprint_input, model_settings, is_training) 103 elif model_architecture == 'low_latency_conv': 104 return create_low_latency_conv_model(fingerprint_input, model_settings, 105 is_training) 106 elif model_architecture == 'low_latency_svdf': 107 return create_low_latency_svdf_model(fingerprint_input, model_settings, 108 is_training, runtime_settings) 109 else: 110 raise Exception('model_architecture argument "' + model_architecture + 111 '" not recognized, should be one of "single_fc", "conv",' + 112 ' "low_latency_conv, or "low_latency_svdf"') 113 114 115 def load_variables_from_checkpoint(sess, start_checkpoint): 116 """Utility function to centralize checkpoint restoration. 117 118 Args: 119 sess: TensorFlow session. 120 start_checkpoint: Path to saved checkpoint on disk. 121 """ 122 saver = tf.train.Saver(tf.global_variables()) 123 saver.restore(sess, start_checkpoint) 124 125 126 def create_single_fc_model(fingerprint_input, model_settings, is_training): 127 """Builds a model with a single hidden fully-connected layer. 128 129 This is a very simple model with just one matmul and bias layer. As you'd 130 expect, it doesn't produce very accurate results, but it is very fast and 131 simple, so it's useful for sanity testing. 132 133 Here's the layout of the graph: 134 135 (fingerprint_input) 136 v 137 [MatMul]<-(weights) 138 v 139 [BiasAdd]<-(bias) 140 v 141 142 Args: 143 fingerprint_input: TensorFlow node that will output audio feature vectors. 144 model_settings: Dictionary of information about the model. 145 is_training: Whether the model is going to be used for training. 146 147 Returns: 148 TensorFlow node outputting logits results, and optionally a dropout 149 placeholder. 150 """ 151 if is_training: 152 dropout_prob = tf.placeholder(tf.float32, name='dropout_prob') 153 fingerprint_size = model_settings['fingerprint_size'] 154 label_count = model_settings['label_count'] 155 weights = tf.Variable( 156 tf.truncated_normal([fingerprint_size, label_count], stddev=0.001)) 157 bias = tf.Variable(tf.zeros([label_count])) 158 logits = tf.matmul(fingerprint_input, weights) + bias 159 if is_training: 160 return logits, dropout_prob 161 else: 162 return logits 163 164 165 def create_conv_model(fingerprint_input, model_settings, is_training): 166 """Builds a standard convolutional model. 167 168 This is roughly the network labeled as 'cnn-trad-fpool3' in the 169 'Convolutional Neural Networks for Small-footprint Keyword Spotting' paper: 170 http://www.isca-speech.org/archive/interspeech_2015/papers/i15_1478.pdf 171 172 Here's the layout of the graph: 173 174 (fingerprint_input) 175 v 176 [Conv2D]<-(weights) 177 v 178 [BiasAdd]<-(bias) 179 v 180 [Relu] 181 v 182 [MaxPool] 183 v 184 [Conv2D]<-(weights) 185 v 186 [BiasAdd]<-(bias) 187 v 188 [Relu] 189 v 190 [MaxPool] 191 v 192 [MatMul]<-(weights) 193 v 194 [BiasAdd]<-(bias) 195 v 196 197 This produces fairly good quality results, but can involve a large number of 198 weight parameters and computations. For a cheaper alternative from the same 199 paper with slightly less accuracy, see 'low_latency_conv' below. 200 201 During training, dropout nodes are introduced after each relu, controlled by a 202 placeholder. 203 204 Args: 205 fingerprint_input: TensorFlow node that will output audio feature vectors. 206 model_settings: Dictionary of information about the model. 207 is_training: Whether the model is going to be used for training. 208 209 Returns: 210 TensorFlow node outputting logits results, and optionally a dropout 211 placeholder. 212 """ 213 if is_training: 214 dropout_prob = tf.placeholder(tf.float32, name='dropout_prob') 215 input_frequency_size = model_settings['dct_coefficient_count'] 216 input_time_size = model_settings['spectrogram_length'] 217 fingerprint_4d = tf.reshape(fingerprint_input, 218 [-1, input_time_size, input_frequency_size, 1]) 219 first_filter_width = 8 220 first_filter_height = 20 221 first_filter_count = 64 222 first_weights = tf.Variable( 223 tf.truncated_normal( 224 [first_filter_height, first_filter_width, 1, first_filter_count], 225 stddev=0.01)) 226 first_bias = tf.Variable(tf.zeros([first_filter_count])) 227 first_conv = tf.nn.conv2d(fingerprint_4d, first_weights, [1, 1, 1, 1], 228 'SAME') + first_bias 229 first_relu = tf.nn.relu(first_conv) 230 if is_training: 231 first_dropout = tf.nn.dropout(first_relu, dropout_prob) 232 else: 233 first_dropout = first_relu 234 max_pool = tf.nn.max_pool(first_dropout, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME') 235 second_filter_width = 4 236 second_filter_height = 10 237 second_filter_count = 64 238 second_weights = tf.Variable( 239 tf.truncated_normal( 240 [ 241 second_filter_height, second_filter_width, first_filter_count, 242 second_filter_count 243 ], 244 stddev=0.01)) 245 second_bias = tf.Variable(tf.zeros([second_filter_count])) 246 second_conv = tf.nn.conv2d(max_pool, second_weights, [1, 1, 1, 1], 247 'SAME') + second_bias 248 second_relu = tf.nn.relu(second_conv) 249 if is_training: 250 second_dropout = tf.nn.dropout(second_relu, dropout_prob) 251 else: 252 second_dropout = second_relu 253 second_conv_shape = second_dropout.get_shape() 254 second_conv_output_width = second_conv_shape[2] 255 second_conv_output_height = second_conv_shape[1] 256 second_conv_element_count = int( 257 second_conv_output_width * second_conv_output_height * 258 second_filter_count) 259 flattened_second_conv = tf.reshape(second_dropout, 260 [-1, second_conv_element_count]) 261 label_count = model_settings['label_count'] 262 final_fc_weights = tf.Variable( 263 tf.truncated_normal( 264 [second_conv_element_count, label_count], stddev=0.01)) 265 final_fc_bias = tf.Variable(tf.zeros([label_count])) 266 final_fc = tf.matmul(flattened_second_conv, final_fc_weights) + final_fc_bias 267 if is_training: 268 return final_fc, dropout_prob 269 else: 270 return final_fc 271 272 273 def create_low_latency_conv_model(fingerprint_input, model_settings, 274 is_training): 275 """Builds a convolutional model with low compute requirements. 276 277 This is roughly the network labeled as 'cnn-one-fstride4' in the 278 'Convolutional Neural Networks for Small-footprint Keyword Spotting' paper: 279 http://www.isca-speech.org/archive/interspeech_2015/papers/i15_1478.pdf 280 281 Here's the layout of the graph: 282 283 (fingerprint_input) 284 v 285 [Conv2D]<-(weights) 286 v 287 [BiasAdd]<-(bias) 288 v 289 [Relu] 290 v 291 [MatMul]<-(weights) 292 v 293 [BiasAdd]<-(bias) 294 v 295 [MatMul]<-(weights) 296 v 297 [BiasAdd]<-(bias) 298 v 299 [MatMul]<-(weights) 300 v 301 [BiasAdd]<-(bias) 302 v 303 304 This produces slightly lower quality results than the 'conv' model, but needs 305 fewer weight parameters and computations. 306 307 During training, dropout nodes are introduced after the relu, controlled by a 308 placeholder. 309 310 Args: 311 fingerprint_input: TensorFlow node that will output audio feature vectors. 312 model_settings: Dictionary of information about the model. 313 is_training: Whether the model is going to be used for training. 314 315 Returns: 316 TensorFlow node outputting logits results, and optionally a dropout 317 placeholder. 318 """ 319 if is_training: 320 dropout_prob = tf.placeholder(tf.float32, name='dropout_prob') 321 input_frequency_size = model_settings['dct_coefficient_count'] 322 input_time_size = model_settings['spectrogram_length'] 323 fingerprint_4d = tf.reshape(fingerprint_input, 324 [-1, input_time_size, input_frequency_size, 1]) 325 first_filter_width = 8 326 first_filter_height = input_time_size 327 first_filter_count = 186 328 first_filter_stride_x = 1 329 first_filter_stride_y = 1 330 first_weights = tf.Variable( 331 tf.truncated_normal( 332 [first_filter_height, first_filter_width, 1, first_filter_count], 333 stddev=0.01)) 334 first_bias = tf.Variable(tf.zeros([first_filter_count])) 335 first_conv = tf.nn.conv2d(fingerprint_4d, first_weights, [ 336 1, first_filter_stride_y, first_filter_stride_x, 1 337 ], 'VALID') + first_bias 338 first_relu = tf.nn.relu(first_conv) 339 if is_training: 340 first_dropout = tf.nn.dropout(first_relu, dropout_prob) 341 else: 342 first_dropout = first_relu 343 first_conv_output_width = math.floor( 344 (input_frequency_size - first_filter_width + first_filter_stride_x) / 345 first_filter_stride_x) 346 first_conv_output_height = math.floor( 347 (input_time_size - first_filter_height + first_filter_stride_y) / 348 first_filter_stride_y) 349 first_conv_element_count = int( 350 first_conv_output_width * first_conv_output_height * first_filter_count) 351 flattened_first_conv = tf.reshape(first_dropout, 352 [-1, first_conv_element_count]) 353 first_fc_output_channels = 128 354 first_fc_weights = tf.Variable( 355 tf.truncated_normal( 356 [first_conv_element_count, first_fc_output_channels], stddev=0.01)) 357 first_fc_bias = tf.Variable(tf.zeros([first_fc_output_channels])) 358 first_fc = tf.matmul(flattened_first_conv, first_fc_weights) + first_fc_bias 359 if is_training: 360 second_fc_input = tf.nn.dropout(first_fc, dropout_prob) 361 else: 362 second_fc_input = first_fc 363 second_fc_output_channels = 128 364 second_fc_weights = tf.Variable( 365 tf.truncated_normal( 366 [first_fc_output_channels, second_fc_output_channels], stddev=0.01)) 367 second_fc_bias = tf.Variable(tf.zeros([second_fc_output_channels])) 368 second_fc = tf.matmul(second_fc_input, second_fc_weights) + second_fc_bias 369 if is_training: 370 final_fc_input = tf.nn.dropout(second_fc, dropout_prob) 371 else: 372 final_fc_input = second_fc 373 label_count = model_settings['label_count'] 374 final_fc_weights = tf.Variable( 375 tf.truncated_normal( 376 [second_fc_output_channels, label_count], stddev=0.01)) 377 final_fc_bias = tf.Variable(tf.zeros([label_count])) 378 final_fc = tf.matmul(final_fc_input, final_fc_weights) + final_fc_bias 379 if is_training: 380 return final_fc, dropout_prob 381 else: 382 return final_fc 383 384 385 def create_low_latency_svdf_model(fingerprint_input, model_settings, 386 is_training, runtime_settings): 387 """Builds an SVDF model with low compute requirements. 388 389 This is based in the topology presented in the 'Compressing Deep Neural 390 Networks using a Rank-Constrained Topology' paper: 391 https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/43813.pdf 392 393 Here's the layout of the graph: 394 395 (fingerprint_input) 396 v 397 [SVDF]<-(weights) 398 v 399 [BiasAdd]<-(bias) 400 v 401 [Relu] 402 v 403 [MatMul]<-(weights) 404 v 405 [BiasAdd]<-(bias) 406 v 407 [MatMul]<-(weights) 408 v 409 [BiasAdd]<-(bias) 410 v 411 [MatMul]<-(weights) 412 v 413 [BiasAdd]<-(bias) 414 v 415 416 This model produces lower recognition accuracy than the 'conv' model above, 417 but requires fewer weight parameters and, significantly fewer computations. 418 419 During training, dropout nodes are introduced after the relu, controlled by a 420 placeholder. 421 422 Args: 423 fingerprint_input: TensorFlow node that will output audio feature vectors. 424 The node is expected to produce a 2D Tensor of shape: 425 [batch, model_settings['dct_coefficient_count'] * 426 model_settings['spectrogram_length']] 427 with the features corresponding to the same time slot arranged contiguously, 428 and the oldest slot at index [:, 0], and newest at [:, -1]. 429 model_settings: Dictionary of information about the model. 430 is_training: Whether the model is going to be used for training. 431 runtime_settings: Dictionary of information about the runtime. 432 433 Returns: 434 TensorFlow node outputting logits results, and optionally a dropout 435 placeholder. 436 437 Raises: 438 ValueError: If the inputs tensor is incorrectly shaped. 439 """ 440 if is_training: 441 dropout_prob = tf.placeholder(tf.float32, name='dropout_prob') 442 443 input_frequency_size = model_settings['dct_coefficient_count'] 444 input_time_size = model_settings['spectrogram_length'] 445 446 # Validation. 447 input_shape = fingerprint_input.get_shape() 448 if len(input_shape) != 2: 449 raise ValueError('Inputs to `SVDF` should have rank == 2.') 450 if input_shape[-1].value is None: 451 raise ValueError('The last dimension of the inputs to `SVDF` ' 452 'should be defined. Found `None`.') 453 if input_shape[-1].value % input_frequency_size != 0: 454 raise ValueError('Inputs feature dimension %d must be a multiple of ' 455 'frame size %d', fingerprint_input.shape[-1].value, 456 input_frequency_size) 457 458 # Set number of units (i.e. nodes) and rank. 459 rank = 2 460 num_units = 1280 461 # Number of filters: pairs of feature and time filters. 462 num_filters = rank * num_units 463 # Create the runtime memory: [num_filters, batch, input_time_size] 464 batch = 1 465 memory = tf.Variable(tf.zeros([num_filters, batch, input_time_size]), 466 trainable=False, name='runtime-memory') 467 # Determine the number of new frames in the input, such that we only operate 468 # on those. For training we do not use the memory, and thus use all frames 469 # provided in the input. 470 # new_fingerprint_input: [batch, num_new_frames*input_frequency_size] 471 if is_training: 472 num_new_frames = input_time_size 473 else: 474 window_stride_ms = int(model_settings['window_stride_samples'] * 1000 / 475 model_settings['sample_rate']) 476 num_new_frames = tf.cond( 477 tf.equal(tf.count_nonzero(memory), 0), 478 lambda: input_time_size, 479 lambda: int(runtime_settings['clip_stride_ms'] / window_stride_ms)) 480 new_fingerprint_input = fingerprint_input[ 481 :, -num_new_frames*input_frequency_size:] 482 # Expand to add input channels dimension. 483 new_fingerprint_input = tf.expand_dims(new_fingerprint_input, 2) 484 485 # Create the frequency filters. 486 weights_frequency = tf.Variable( 487 tf.truncated_normal([input_frequency_size, num_filters], stddev=0.01)) 488 # Expand to add input channels dimensions. 489 # weights_frequency: [input_frequency_size, 1, num_filters] 490 weights_frequency = tf.expand_dims(weights_frequency, 1) 491 # Convolve the 1D feature filters sliding over the time dimension. 492 # activations_time: [batch, num_new_frames, num_filters] 493 activations_time = tf.nn.conv1d( 494 new_fingerprint_input, weights_frequency, input_frequency_size, 'VALID') 495 # Rearrange such that we can perform the batched matmul. 496 # activations_time: [num_filters, batch, num_new_frames] 497 activations_time = tf.transpose(activations_time, perm=[2, 0, 1]) 498 499 # Runtime memory optimization. 500 if not is_training: 501 # We need to drop the activations corresponding to the oldest frames, and 502 # then add those corresponding to the new frames. 503 new_memory = memory[:, :, num_new_frames:] 504 new_memory = tf.concat([new_memory, activations_time], 2) 505 tf.assign(memory, new_memory) 506 activations_time = new_memory 507 508 # Create the time filters. 509 weights_time = tf.Variable( 510 tf.truncated_normal([num_filters, input_time_size], stddev=0.01)) 511 # Apply the time filter on the outputs of the feature filters. 512 # weights_time: [num_filters, input_time_size, 1] 513 # outputs: [num_filters, batch, 1] 514 weights_time = tf.expand_dims(weights_time, 2) 515 outputs = tf.matmul(activations_time, weights_time) 516 # Split num_units and rank into separate dimensions (the remaining 517 # dimension is the input_shape[0] -i.e. batch size). This also squeezes 518 # the last dimension, since it's not used. 519 # [num_filters, batch, 1] => [num_units, rank, batch] 520 outputs = tf.reshape(outputs, [num_units, rank, -1]) 521 # Sum the rank outputs per unit => [num_units, batch]. 522 units_output = tf.reduce_sum(outputs, axis=1) 523 # Transpose to shape [batch, num_units] 524 units_output = tf.transpose(units_output) 525 526 # Appy bias. 527 bias = tf.Variable(tf.zeros([num_units])) 528 first_bias = tf.nn.bias_add(units_output, bias) 529 530 # Relu. 531 first_relu = tf.nn.relu(first_bias) 532 533 if is_training: 534 first_dropout = tf.nn.dropout(first_relu, dropout_prob) 535 else: 536 first_dropout = first_relu 537 538 first_fc_output_channels = 256 539 first_fc_weights = tf.Variable( 540 tf.truncated_normal([num_units, first_fc_output_channels], stddev=0.01)) 541 first_fc_bias = tf.Variable(tf.zeros([first_fc_output_channels])) 542 first_fc = tf.matmul(first_dropout, first_fc_weights) + first_fc_bias 543 if is_training: 544 second_fc_input = tf.nn.dropout(first_fc, dropout_prob) 545 else: 546 second_fc_input = first_fc 547 second_fc_output_channels = 256 548 second_fc_weights = tf.Variable( 549 tf.truncated_normal( 550 [first_fc_output_channels, second_fc_output_channels], stddev=0.01)) 551 second_fc_bias = tf.Variable(tf.zeros([second_fc_output_channels])) 552 second_fc = tf.matmul(second_fc_input, second_fc_weights) + second_fc_bias 553 if is_training: 554 final_fc_input = tf.nn.dropout(second_fc, dropout_prob) 555 else: 556 final_fc_input = second_fc 557 label_count = model_settings['label_count'] 558 final_fc_weights = tf.Variable( 559 tf.truncated_normal( 560 [second_fc_output_channels, label_count], stddev=0.01)) 561 final_fc_bias = tf.Variable(tf.zeros([label_count])) 562 final_fc = tf.matmul(final_fc_input, final_fc_weights) + final_fc_bias 563 if is_training: 564 return final_fc, dropout_prob 565 else: 566 return final_fc 567