# tensorflow/examples/speech_commands/label_wav.py
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 r"""Runs a trained audio graph against a WAVE file and reports the results.
     16 
     17 The model, labels and .wav file specified in the arguments will be loaded, and
     18 then the predictions from running the model against the audio data will be
     19 printed to the console. This is a useful script for sanity checking trained
     20 models, and as an example of how to use an audio model from Python.
     21 
     22 Here's an example of running it:
     23 
     24 python tensorflow/examples/speech_commands/label_wav.py \
     25 --graph=/tmp/my_frozen_graph.pb \
     26 --labels=/tmp/speech_commands_train/conv_labels.txt \
     27 --wav=/tmp/speech_dataset/left/a5d485dc_nohash_0.wav
     28 
     29 """
     30 from __future__ import absolute_import
     31 from __future__ import division
     32 from __future__ import print_function
     33 
     34 import argparse
     35 import sys
     36 
     37 import tensorflow as tf
     38 
     39 # pylint: disable=unused-import
     40 from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
     41 # pylint: enable=unused-import
     42 
# Parsed command-line flags; populated by argparse in the __main__ block below.
FLAGS = None
     44 
     45 
     46 def load_graph(filename):
     47   """Unpersists graph from file as default graph."""
     48   with tf.gfile.FastGFile(filename, 'rb') as f:
     49     graph_def = tf.GraphDef()
     50     graph_def.ParseFromString(f.read())
     51     tf.import_graph_def(graph_def, name='')
     52 
     53 
     54 def load_labels(filename):
     55   """Read in labels, one label per line."""
     56   return [line.rstrip() for line in tf.gfile.GFile(filename)]
     57 
     58 
     59 def run_graph(wav_data, labels, input_layer_name, output_layer_name,
     60               num_top_predictions):
     61   """Runs the audio data through the graph and prints predictions."""
     62   with tf.Session() as sess:
     63     # Feed the audio data as input to the graph.
     64     #   predictions  will contain a two-dimensional array, where one
     65     #   dimension represents the input image count, and the other has
     66     #   predictions per class
     67     softmax_tensor = sess.graph.get_tensor_by_name(output_layer_name)
     68     predictions, = sess.run(softmax_tensor, {input_layer_name: wav_data})
     69 
     70     # Sort to show labels in order of confidence
     71     top_k = predictions.argsort()[-num_top_predictions:][::-1]
     72     for node_id in top_k:
     73       human_string = labels[node_id]
     74       score = predictions[node_id]
     75       print('%s (score = %.5f)' % (human_string, score))
     76 
     77     return 0
     78 
     79 
     80 def label_wav(wav, labels, graph, input_name, output_name, how_many_labels):
     81   """Loads the model and labels, and runs the inference to print predictions."""
     82   if not wav or not tf.gfile.Exists(wav):
     83     tf.logging.fatal('Audio file does not exist %s', wav)
     84 
     85   if not labels or not tf.gfile.Exists(labels):
     86     tf.logging.fatal('Labels file does not exist %s', labels)
     87 
     88   if not graph or not tf.gfile.Exists(graph):
     89     tf.logging.fatal('Graph file does not exist %s', graph)
     90 
     91   labels_list = load_labels(labels)
     92 
     93   # load graph, which is stored in the default session
     94   load_graph(graph)
     95 
     96   with open(wav, 'rb') as wav_file:
     97     wav_data = wav_file.read()
     98 
     99   run_graph(wav_data, labels_list, input_name, output_name, how_many_labels)
    100 
    101 
    102 def main(_):
    103   """Entry point for script, converts flags to arguments."""
    104   label_wav(FLAGS.wav, FLAGS.labels, FLAGS.graph, FLAGS.input_name,
    105             FLAGS.output_name, FLAGS.how_many_labels)
    106 
    107 
    108 if __name__ == '__main__':
    109   parser = argparse.ArgumentParser()
    110   parser.add_argument(
    111       '--wav', type=str, default='', help='Audio file to be identified.')
    112   parser.add_argument(
    113       '--graph', type=str, default='', help='Model to use for identification.')
    114   parser.add_argument(
    115       '--labels', type=str, default='', help='Path to file containing labels.')
    116   parser.add_argument(
    117       '--input_name',
    118       type=str,
    119       default='wav_data:0',
    120       help='Name of WAVE data input node in model.')
    121   parser.add_argument(
    122       '--output_name',
    123       type=str,
    124       default='labels_softmax:0',
    125       help='Name of node outputting a prediction in the model.')
    126   parser.add_argument(
    127       '--how_many_labels',
    128       type=int,
    129       default=3,
    130       help='Number of results to show.')
    131 
    132   FLAGS, unparsed = parser.parse_known_args()
    133   tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
    134