1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 r"""Runs a trained audio graph against a WAVE file and reports the results. 16 17 The model, labels and .wav file specified in the arguments will be loaded, and 18 then the predictions from running the model against the audio data will be 19 printed to the console. This is a useful script for sanity checking trained 20 models, and as an example of how to use an audio model from Python. 21 22 Here's an example of running it: 23 24 python tensorflow/examples/speech_commands/label_wav.py \ 25 --graph=/tmp/my_frozen_graph.pb \ 26 --labels=/tmp/speech_commands_train/conv_labels.txt \ 27 --wav=/tmp/speech_dataset/left/a5d485dc_nohash_0.wav 28 29 """ 30 from __future__ import absolute_import 31 from __future__ import division 32 from __future__ import print_function 33 34 import argparse 35 import sys 36 37 import tensorflow as tf 38 39 # pylint: disable=unused-import 40 from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio 41 # pylint: enable=unused-import 42 43 FLAGS = None 44 45 46 def load_graph(filename): 47 """Unpersists graph from file as default graph.""" 48 with tf.gfile.FastGFile(filename, 'rb') as f: 49 graph_def = tf.GraphDef() 50 graph_def.ParseFromString(f.read()) 51 tf.import_graph_def(graph_def, name='') 52 53 54 def load_labels(filename): 55 """Read in labels, one label per line.""" 56 return [line.rstrip() for line in tf.gfile.GFile(filename)] 57 58 59 def run_graph(wav_data, labels, input_layer_name, output_layer_name, 60 num_top_predictions): 61 """Runs the audio data through the graph and prints predictions.""" 62 with tf.Session() as sess: 63 # Feed the audio data as input to the graph. 64 # predictions will contain a two-dimensional array, where one 65 # dimension represents the input image count, and the other has 66 # predictions per class 67 softmax_tensor = sess.graph.get_tensor_by_name(output_layer_name) 68 predictions, = sess.run(softmax_tensor, {input_layer_name: wav_data}) 69 70 # Sort to show labels in order of confidence 71 top_k = predictions.argsort()[-num_top_predictions:][::-1] 72 for node_id in top_k: 73 human_string = labels[node_id] 74 score = predictions[node_id] 75 print('%s (score = %.5f)' % (human_string, score)) 76 77 return 0 78 79 80 def label_wav(wav, labels, graph, input_name, output_name, how_many_labels): 81 """Loads the model and labels, and runs the inference to print predictions.""" 82 if not wav or not tf.gfile.Exists(wav): 83 tf.logging.fatal('Audio file does not exist %s', wav) 84 85 if not labels or not tf.gfile.Exists(labels): 86 tf.logging.fatal('Labels file does not exist %s', labels) 87 88 if not graph or not tf.gfile.Exists(graph): 89 tf.logging.fatal('Graph file does not exist %s', graph) 90 91 labels_list = load_labels(labels) 92 93 # load graph, which is stored in the default session 94 load_graph(graph) 95 96 with open(wav, 'rb') as wav_file: 97 wav_data = wav_file.read() 98 99 run_graph(wav_data, labels_list, input_name, output_name, how_many_labels) 100 101 102 def main(_): 103 """Entry point for script, converts flags to arguments.""" 104 label_wav(FLAGS.wav, FLAGS.labels, FLAGS.graph, FLAGS.input_name, 105 FLAGS.output_name, FLAGS.how_many_labels) 106 107 108 if __name__ == '__main__': 109 parser = argparse.ArgumentParser() 110 parser.add_argument( 111 '--wav', type=str, default='', help='Audio file to be identified.') 112 parser.add_argument( 113 '--graph', type=str, default='', help='Model to use for identification.') 114 parser.add_argument( 115 '--labels', type=str, default='', help='Path to file containing labels.') 116 parser.add_argument( 117 '--input_name', 118 type=str, 119 default='wav_data:0', 120 help='Name of WAVE data input node in model.') 121 parser.add_argument( 122 '--output_name', 123 type=str, 124 default='labels_softmax:0', 125 help='Name of node outputting a prediction in the model.') 126 parser.add_argument( 127 '--how_many_labels', 128 type=int, 129 default=3, 130 help='Number of results to show.') 131 132 FLAGS, unparsed = parser.parse_known_args() 133 tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 134