# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model definitions for simple speech recognition."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import tensorflow as tf


def prepare_model_settings(label_count, sample_rate, clip_duration_ms,
                           window_size_ms, window_stride_ms,
                           dct_coefficient_count):
  """Calculates common settings needed for all models.

  Args:
    label_count: How many classes are to be recognized.
    sample_rate: Number of audio samples per second.
    clip_duration_ms: Length of each audio clip to be analyzed.
    window_size_ms: Duration of frequency analysis window.
    window_stride_ms: How far to move in time between frequency windows.
    dct_coefficient_count: Number of frequency bins to use for analysis.

  Returns:
    Dictionary containing common settings.
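
  For example, a 16000 Hz sample rate with 1000 ms clips, a 30 ms window, a
  10 ms stride, and 40 DCT coefficients (typical values for this example,
  not requirements of this function) works out to:

    desired_samples       = 16000 * 1000 / 1000     = 16000
    window_size_samples   = 16000 * 30 / 1000       = 480
    window_stride_samples = 16000 * 10 / 1000       = 160
    spectrogram_length    = 1 + (16000 - 480) / 160 = 98
    fingerprint_size      = 40 * 98                 = 3920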
  """
  desired_samples = int(sample_rate * clip_duration_ms / 1000)
  window_size_samples = int(sample_rate * window_size_ms / 1000)
  window_stride_samples = int(sample_rate * window_stride_ms / 1000)
  length_minus_window = (desired_samples - window_size_samples)
  if length_minus_window < 0:
    spectrogram_length = 0
  else:
    spectrogram_length = 1 + int(length_minus_window / window_stride_samples)
  fingerprint_size = dct_coefficient_count * spectrogram_length
  return {
      'desired_samples': desired_samples,
      'window_size_samples': window_size_samples,
      'window_stride_samples': window_stride_samples,
      'spectrogram_length': spectrogram_length,
      'dct_coefficient_count': dct_coefficient_count,
      'fingerprint_size': fingerprint_size,
      'label_count': label_count,
      'sample_rate': sample_rate,
  }


def create_model(fingerprint_input, model_settings, model_architecture,
                 is_training, runtime_settings=None):
  """Builds a model of the requested architecture compatible with the settings.

  There are many possible ways of deriving predictions from a spectrogram
  input, so this function provides an abstract interface for creating different
  kinds of models in a black-box way. You need to pass in a TensorFlow node as
  the 'fingerprint' input, and this should output a batch of 1D features that
  describe the audio. Typically this will be derived from a spectrogram that's
  been run through an MFCC, but in theory it can be any feature vector of the
  size specified in model_settings['fingerprint_size'].

  The function will build the graph it needs in the current TensorFlow graph,
  and return the TensorFlow output that will contain the 'logits' input to the
  softmax prediction process. If the training flag is on, it will also return a
  placeholder node that can be used to control the dropout amount.

  See the implementations below for the possible model architectures that can be
  requested.

  Args:
    fingerprint_input: TensorFlow node that will output audio feature vectors.
    model_settings: Dictionary of information about the model.
    model_architecture: String specifying which kind of model to create.
    is_training: Whether the model is going to be used for training.
    runtime_settings: Dictionary of information about the runtime.

  Returns:
    TensorFlow node outputting logits results, and optionally a dropout
    placeholder.

  Raises:
    Exception: If the architecture type isn't recognized.
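
  Example (a minimal sketch, assuming model_settings came from
  prepare_model_settings() above and a matching placeholder is used as the
  fingerprint input):

    fingerprint_input = tf.placeholder(
        tf.float32, [None, model_settings['fingerprint_size']])
    logits, dropout_prob = create_model(
        fingerprint_input, model_settings, 'conv', is_training=True)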
  """
  if model_architecture == 'single_fc':
    return create_single_fc_model(fingerprint_input, model_settings,
                                  is_training)
  elif model_architecture == 'conv':
    return create_conv_model(fingerprint_input, model_settings, is_training)
  elif model_architecture == 'low_latency_conv':
    return create_low_latency_conv_model(fingerprint_input, model_settings,
                                         is_training)
  elif model_architecture == 'low_latency_svdf':
    return create_low_latency_svdf_model(fingerprint_input, model_settings,
                                         is_training, runtime_settings)
  else:
    raise Exception('model_architecture argument "' + model_architecture +
                    '" not recognized, should be one of "single_fc", "conv",' +
                    ' "low_latency_conv", or "low_latency_svdf"')


def load_variables_from_checkpoint(sess, start_checkpoint):
  """Utility function to centralize checkpoint restoration.

  Args:
    sess: TensorFlow session.
    start_checkpoint: Path to saved checkpoint on disk.
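
  Example (a minimal sketch; the checkpoint path below is hypothetical):

    sess = tf.Session()
    load_variables_from_checkpoint(
        sess, '/tmp/speech_commands_train/conv.ckpt-100')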
  """
  saver = tf.train.Saver(tf.global_variables())
  saver.restore(sess, start_checkpoint)


def create_single_fc_model(fingerprint_input, model_settings, is_training):
  """Builds a model with a single hidden fully-connected layer.

  This is a very simple model with just one matmul and bias layer. As you'd
  expect, it doesn't produce very accurate results, but it is very fast and
  simple, so it's useful for sanity testing.

  Here's the layout of the graph:

  (fingerprint_input)
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v

  Args:
    fingerprint_input: TensorFlow node that will output audio feature vectors.
    model_settings: Dictionary of information about the model.
    is_training: Whether the model is going to be used for training.

  Returns:
    TensorFlow node outputting logits results, and optionally a dropout
    placeholder.
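
  Shape sketch (an illustration, with B standing for the batch size): the
  fingerprint input is [B, fingerprint_size], the weights are
  [fingerprint_size, label_count], and the bias is [label_count], so the
  logits come out as [B, label_count].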
  """
  if is_training:
    dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
  fingerprint_size = model_settings['fingerprint_size']
  label_count = model_settings['label_count']
  weights = tf.Variable(
      tf.truncated_normal([fingerprint_size, label_count], stddev=0.001))
  bias = tf.Variable(tf.zeros([label_count]))
  logits = tf.matmul(fingerprint_input, weights) + bias
  if is_training:
    return logits, dropout_prob
  else:
    return logits


def create_conv_model(fingerprint_input, model_settings, is_training):
  """Builds a standard convolutional model.

  This is roughly the network labeled as 'cnn-trad-fpool3' in the
  'Convolutional Neural Networks for Small-footprint Keyword Spotting' paper:
  http://www.isca-speech.org/archive/interspeech_2015/papers/i15_1478.pdf

  Here's the layout of the graph:

  (fingerprint_input)
          v
      [Conv2D]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
        [Relu]
          v
      [MaxPool]
          v
      [Conv2D]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
        [Relu]
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v

  This produces fairly good quality results, but can involve a large number of
  weight parameters and computations. For a cheaper alternative from the same
  paper with slightly less accuracy, see 'low_latency_conv' below.
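
  To get a feel for the size, with the 98x40 fingerprint from the settings
  example above and an assumed 12 output labels (the label count is not fixed
  by this function), the weight counts are roughly:

    first conv:  20 * 8 * 1 * 64     = 10,240
    second conv: 10 * 4 * 64 * 64    = 163,840
    final fc:    (49 * 20 * 64) * 12 = 752,640

  or a little under a million parameters, dominated by the final matmul.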

  During training, dropout nodes are introduced after each relu, controlled by a
  placeholder.

  Args:
    fingerprint_input: TensorFlow node that will output audio feature vectors.
    model_settings: Dictionary of information about the model.
    is_training: Whether the model is going to be used for training.

  Returns:
    TensorFlow node outputting logits results, and optionally a dropout
    placeholder.
  """
  if is_training:
    dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
  input_frequency_size = model_settings['dct_coefficient_count']
  input_time_size = model_settings['spectrogram_length']
  fingerprint_4d = tf.reshape(fingerprint_input,
                              [-1, input_time_size, input_frequency_size, 1])
  first_filter_width = 8
  first_filter_height = 20
  first_filter_count = 64
  first_weights = tf.Variable(
      tf.truncated_normal(
          [first_filter_height, first_filter_width, 1, first_filter_count],
          stddev=0.01))
  first_bias = tf.Variable(tf.zeros([first_filter_count]))
  first_conv = tf.nn.conv2d(fingerprint_4d, first_weights, [1, 1, 1, 1],
                            'SAME') + first_bias
  first_relu = tf.nn.relu(first_conv)
  if is_training:
    first_dropout = tf.nn.dropout(first_relu, dropout_prob)
  else:
    first_dropout = first_relu
  max_pool = tf.nn.max_pool(first_dropout, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME')
  second_filter_width = 4
  second_filter_height = 10
  second_filter_count = 64
  second_weights = tf.Variable(
      tf.truncated_normal(
          [
              second_filter_height, second_filter_width, first_filter_count,
              second_filter_count
          ],
          stddev=0.01))
  second_bias = tf.Variable(tf.zeros([second_filter_count]))
  second_conv = tf.nn.conv2d(max_pool, second_weights, [1, 1, 1, 1],
                             'SAME') + second_bias
  second_relu = tf.nn.relu(second_conv)
  if is_training:
    second_dropout = tf.nn.dropout(second_relu, dropout_prob)
  else:
    second_dropout = second_relu
  second_conv_shape = second_dropout.get_shape()
  second_conv_output_width = second_conv_shape[2]
  second_conv_output_height = second_conv_shape[1]
  second_conv_element_count = int(
      second_conv_output_width * second_conv_output_height *
      second_filter_count)
  flattened_second_conv = tf.reshape(second_dropout,
                                     [-1, second_conv_element_count])
  label_count = model_settings['label_count']
  final_fc_weights = tf.Variable(
      tf.truncated_normal(
          [second_conv_element_count, label_count], stddev=0.01))
  final_fc_bias = tf.Variable(tf.zeros([label_count]))
  final_fc = tf.matmul(flattened_second_conv, final_fc_weights) + final_fc_bias
  if is_training:
    return final_fc, dropout_prob
  else:
    return final_fc


def create_low_latency_conv_model(fingerprint_input, model_settings,
                                  is_training):
  """Builds a convolutional model with low compute requirements.

  This is roughly the network labeled as 'cnn-one-fstride4' in the
  'Convolutional Neural Networks for Small-footprint Keyword Spotting' paper:
  http://www.isca-speech.org/archive/interspeech_2015/papers/i15_1478.pdf

  Here's the layout of the graph:

  (fingerprint_input)
          v
      [Conv2D]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
        [Relu]
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v

  This produces slightly lower quality results than the 'conv' model, but needs
  fewer weight parameters and computations.
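
  As a rough size check, with the 98x40 fingerprint from the settings example
  above, the single 98x8 filter (applied with a stride of 1 and VALID padding)
  produces a 1x33 output per filter, so the flattened conv output feeding the
  first fully-connected layer is 33 * 186 = 6,138 values.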

  During training, dropout nodes are introduced after the relu, controlled by a
  placeholder.

  Args:
    fingerprint_input: TensorFlow node that will output audio feature vectors.
    model_settings: Dictionary of information about the model.
    is_training: Whether the model is going to be used for training.

  Returns:
    TensorFlow node outputting logits results, and optionally a dropout
    placeholder.
  """
  if is_training:
    dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
  input_frequency_size = model_settings['dct_coefficient_count']
  input_time_size = model_settings['spectrogram_length']
  fingerprint_4d = tf.reshape(fingerprint_input,
                              [-1, input_time_size, input_frequency_size, 1])
  first_filter_width = 8
  first_filter_height = input_time_size
  first_filter_count = 186
  first_filter_stride_x = 1
  first_filter_stride_y = 1
  first_weights = tf.Variable(
      tf.truncated_normal(
          [first_filter_height, first_filter_width, 1, first_filter_count],
          stddev=0.01))
  first_bias = tf.Variable(tf.zeros([first_filter_count]))
  first_conv = tf.nn.conv2d(fingerprint_4d, first_weights, [
      1, first_filter_stride_y, first_filter_stride_x, 1
  ], 'VALID') + first_bias
  first_relu = tf.nn.relu(first_conv)
  if is_training:
    first_dropout = tf.nn.dropout(first_relu, dropout_prob)
  else:
    first_dropout = first_relu
  first_conv_output_width = math.floor(
      (input_frequency_size - first_filter_width + first_filter_stride_x) /
      first_filter_stride_x)
  first_conv_output_height = math.floor(
      (input_time_size - first_filter_height + first_filter_stride_y) /
      first_filter_stride_y)
  first_conv_element_count = int(
      first_conv_output_width * first_conv_output_height * first_filter_count)
  flattened_first_conv = tf.reshape(first_dropout,
                                    [-1, first_conv_element_count])
  first_fc_output_channels = 128
  first_fc_weights = tf.Variable(
      tf.truncated_normal(
          [first_conv_element_count, first_fc_output_channels], stddev=0.01))
  first_fc_bias = tf.Variable(tf.zeros([first_fc_output_channels]))
  first_fc = tf.matmul(flattened_first_conv, first_fc_weights) + first_fc_bias
  if is_training:
    second_fc_input = tf.nn.dropout(first_fc, dropout_prob)
  else:
    second_fc_input = first_fc
  second_fc_output_channels = 128
  second_fc_weights = tf.Variable(
      tf.truncated_normal(
          [first_fc_output_channels, second_fc_output_channels], stddev=0.01))
  second_fc_bias = tf.Variable(tf.zeros([second_fc_output_channels]))
  second_fc = tf.matmul(second_fc_input, second_fc_weights) + second_fc_bias
  if is_training:
    final_fc_input = tf.nn.dropout(second_fc, dropout_prob)
  else:
    final_fc_input = second_fc
  label_count = model_settings['label_count']
  final_fc_weights = tf.Variable(
      tf.truncated_normal(
          [second_fc_output_channels, label_count], stddev=0.01))
  final_fc_bias = tf.Variable(tf.zeros([label_count]))
  final_fc = tf.matmul(final_fc_input, final_fc_weights) + final_fc_bias
  if is_training:
    return final_fc, dropout_prob
  else:
    return final_fc


def create_low_latency_svdf_model(fingerprint_input, model_settings,
                                  is_training, runtime_settings):
  """Builds an SVDF model with low compute requirements.

  This is based on the topology presented in the 'Compressing Deep Neural
  Networks using a Rank-Constrained Topology' paper:
  https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/43813.pdf

  Here's the layout of the graph:

  (fingerprint_input)
          v
        [SVDF]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
        [Relu]
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v

  This model produces lower recognition accuracy than the 'conv' model above,
  but requires fewer weight parameters and significantly fewer computations.
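
  As a rough illustration, with the 98x40 fingerprint from the settings example
  above, the rank-2 SVDF layer below stores a [40, 2560] frequency filter and a
  [2560, 98] time filter, about 353K weights in total, whereas a dense layer
  mapping the full 3,920-value fingerprint onto the same 1,280 units would need
  roughly 5 million.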

  During training, dropout nodes are introduced after the relu, controlled by a
  placeholder.

  Args:
    fingerprint_input: TensorFlow node that will output audio feature vectors.
      The node is expected to produce a 2D Tensor of shape:
        [batch, model_settings['dct_coefficient_count'] *
                model_settings['spectrogram_length']]
      with the features corresponding to the same time slot arranged
      contiguously, the oldest slot at index [:, 0] and the newest at [:, -1].
    model_settings: Dictionary of information about the model.
    is_training: Whether the model is going to be used for training.
    runtime_settings: Dictionary of information about the runtime.

  Returns:
    TensorFlow node outputting logits results, and optionally a dropout
    placeholder.

  Raises:
    ValueError: If the inputs tensor is incorrectly shaped.
  """
  if is_training:
    dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')

  input_frequency_size = model_settings['dct_coefficient_count']
  input_time_size = model_settings['spectrogram_length']

  # Validation.
  input_shape = fingerprint_input.get_shape()
  if len(input_shape) != 2:
    raise ValueError('Inputs to `SVDF` should have rank == 2.')
  if input_shape[-1].value is None:
    raise ValueError('The last dimension of the inputs to `SVDF` '
                     'should be defined. Found `None`.')
  if input_shape[-1].value % input_frequency_size != 0:
    raise ValueError('Inputs feature dimension %d must be a multiple of '
                     'frame size %d' % (fingerprint_input.shape[-1].value,
                                        input_frequency_size))

  # Set number of units (i.e. nodes) and rank.
  rank = 2
  num_units = 1280
  # Number of filters: pairs of feature and time filters.
  num_filters = rank * num_units
  # Create the runtime memory: [num_filters, batch, input_time_size]
  batch = 1
  memory = tf.Variable(tf.zeros([num_filters, batch, input_time_size]),
                       trainable=False, name='runtime-memory')
  # Determine the number of new frames in the input, such that we only operate
  # on those. For training we do not use the memory, and thus use all frames
  # provided in the input.
  # new_fingerprint_input: [batch, num_new_frames*input_frequency_size]
  if is_training:
    num_new_frames = input_time_size
  else:
    window_stride_ms = int(model_settings['window_stride_samples'] * 1000 /
                           model_settings['sample_rate'])
    num_new_frames = tf.cond(
        tf.equal(tf.count_nonzero(memory), 0),
        lambda: input_time_size,
        lambda: int(runtime_settings['clip_stride_ms'] / window_stride_ms))
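  # For example, with a 10 ms window stride and a 30 ms clip stride (an assumed
  # runtime setting, not fixed here), only the 3 newest frames need processing
  # on each call once the memory has been populated.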
  new_fingerprint_input = fingerprint_input[
      :, -num_new_frames*input_frequency_size:]
  # Expand to add input channels dimension.
  new_fingerprint_input = tf.expand_dims(new_fingerprint_input, 2)

  # Create the frequency filters.
  weights_frequency = tf.Variable(
      tf.truncated_normal([input_frequency_size, num_filters], stddev=0.01))
  # Expand to add input channels dimensions.
  # weights_frequency: [input_frequency_size, 1, num_filters]
  weights_frequency = tf.expand_dims(weights_frequency, 1)
  # Convolve the 1D feature filters sliding over the time dimension.
  # activations_time: [batch, num_new_frames, num_filters]
  activations_time = tf.nn.conv1d(
      new_fingerprint_input, weights_frequency, input_frequency_size, 'VALID')
  # Rearrange such that we can perform the batched matmul.
  # activations_time: [num_filters, batch, num_new_frames]
  activations_time = tf.transpose(activations_time, perm=[2, 0, 1])

  # Runtime memory optimization.
  if not is_training:
    # We need to drop the activations corresponding to the oldest frames, and
    # then add those corresponding to the new frames.
    new_memory = memory[:, :, num_new_frames:]
    new_memory = tf.concat([new_memory, activations_time], 2)
    tf.assign(memory, new_memory)
    activations_time = new_memory

  # Create the time filters.
  weights_time = tf.Variable(
      tf.truncated_normal([num_filters, input_time_size], stddev=0.01))
  # Apply the time filter on the outputs of the feature filters.
  # weights_time: [num_filters, input_time_size, 1]
  # outputs: [num_filters, batch, 1]
  weights_time = tf.expand_dims(weights_time, 2)
  outputs = tf.matmul(activations_time, weights_time)
  # Split num_units and rank into separate dimensions (the remaining
  # dimension is the input_shape[0], i.e. the batch size). This also squeezes
  # the last dimension, since it's not used.
  # [num_filters, batch, 1] => [num_units, rank, batch]
  outputs = tf.reshape(outputs, [num_units, rank, -1])
  # Sum the rank outputs per unit => [num_units, batch].
  units_output = tf.reduce_sum(outputs, axis=1)
  # Transpose to shape [batch, num_units]
  units_output = tf.transpose(units_output)

  # Apply bias.
  bias = tf.Variable(tf.zeros([num_units]))
  first_bias = tf.nn.bias_add(units_output, bias)

  # Relu.
  first_relu = tf.nn.relu(first_bias)

  if is_training:
    first_dropout = tf.nn.dropout(first_relu, dropout_prob)
  else:
    first_dropout = first_relu

  first_fc_output_channels = 256
  first_fc_weights = tf.Variable(
      tf.truncated_normal([num_units, first_fc_output_channels], stddev=0.01))
  first_fc_bias = tf.Variable(tf.zeros([first_fc_output_channels]))
  first_fc = tf.matmul(first_dropout, first_fc_weights) + first_fc_bias
  if is_training:
    second_fc_input = tf.nn.dropout(first_fc, dropout_prob)
  else:
    second_fc_input = first_fc
  second_fc_output_channels = 256
  second_fc_weights = tf.Variable(
      tf.truncated_normal(
          [first_fc_output_channels, second_fc_output_channels], stddev=0.01))
  second_fc_bias = tf.Variable(tf.zeros([second_fc_output_channels]))
  second_fc = tf.matmul(second_fc_input, second_fc_weights) + second_fc_bias
  if is_training:
    final_fc_input = tf.nn.dropout(second_fc, dropout_prob)
  else:
    final_fc_input = second_fc
  label_count = model_settings['label_count']
  final_fc_weights = tf.Variable(
      tf.truncated_normal(
          [second_fc_output_channels, label_count], stddev=0.01))
  final_fc_bias = tf.Variable(tf.zeros([label_count]))
  final_fc = tf.matmul(final_fc_input, final_fc_weights) + final_fc_bias
  if is_training:
    return final_fc, dropout_prob
  else:
    return final_fc