1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 /* 16 17 Tool to create accuracy statistics from running an audio recognition model on a 18 continuous stream of samples. 19 20 This is designed to be an environment for running experiments on new models and 21 settings to understand the effects they will have in a real application. You 22 need to supply it with a long audio file containing sounds you want to recognize 23 and a text file listing the labels of each sound along with the time they occur. 24 With this information, and a frozen model, the tool will process the audio 25 stream, apply the model, and keep track of how many mistakes and successes the 26 model achieved. 27 28 The matched percentage is the number of sounds that were correctly classified, 29 as a percentage of the total number of sounds listed in the ground truth file. 30 A correct classification is when the right label is chosen within a short time 31 of the expected ground truth, where the time tolerance is controlled by the 32 'time_tolerance_ms' command line flag. 33 34 The wrong percentage is how many sounds triggered a detection (the classifier 35 figured out it wasn't silence or background noise), but the detected class was 36 wrong. This is also a percentage of the total number of ground truth sounds. 
37 38 The false positive percentage is how many sounds were detected when there was 39 only silence or background noise. This is also expressed as a percentage of the 40 total number of ground truth sounds, though since it can be large it may go 41 above 100%. 42 43 The easiest way to get an audio file and labels to test with is by using the 44 'generate_streaming_test_wav' script. This will synthesize a test file with 45 randomly placed sounds and background noise, and output a text file with the 46 ground truth. 47 48 If you want to test natural data, you need to use a .wav with the same sample 49 rate as your model (often 16,000 samples per second), and note down where the 50 sounds occur in time. Save this information out as a comma-separated text file, 51 where the first column is the label and the second is the time in seconds from 52 the start of the file that it occurs. 53 54 Here's an example of how to run the tool: 55 56 bazel run tensorflow/examples/speech_commands:test_streaming_accuracy -- \ 57 --wav=/tmp/streaming_test_bg.wav \ 58 --graph=/tmp/conv_frozen.pb \ 59 --labels=/tmp/speech_commands_train/conv_labels.txt \ 60 --ground_truth=/tmp/streaming_test_labels.txt --verbose \ 61 --clip_duration_ms=1000 --detection_threshold=0.70 --average_window_ms=500 \ 62 --suppression_ms=500 --time_tolerance_ms=1500 63 64 */ 65 66 #include <fstream> 67 #include <iomanip> 68 #include <unordered_set> 69 #include <vector> 70 71 #include "tensorflow/core/framework/tensor.h" 72 #include "tensorflow/core/lib/io/path.h" 73 #include "tensorflow/core/lib/strings/numbers.h" 74 #include "tensorflow/core/lib/strings/str_util.h" 75 #include "tensorflow/core/lib/wav/wav_io.h" 76 #include "tensorflow/core/platform/init_main.h" 77 #include "tensorflow/core/platform/logging.h" 78 #include "tensorflow/core/platform/types.h" 79 #include "tensorflow/core/public/session.h" 80 #include "tensorflow/core/util/command_line_flags.h" 81 #include 
"tensorflow/examples/speech_commands/accuracy_utils.h" 82 #include "tensorflow/examples/speech_commands/recognize_commands.h" 83 84 // These are all common classes it's handy to reference with no namespace. 85 using tensorflow::Flag; 86 using tensorflow::Status; 87 using tensorflow::Tensor; 88 using tensorflow::int32; 89 using tensorflow::int64; 90 using tensorflow::string; 91 using tensorflow::uint16; 92 using tensorflow::uint32; 93 94 namespace { 95 96 // Reads a model graph definition from disk, and creates a session object you 97 // can use to run it. 98 Status LoadGraph(const string& graph_file_name, 99 std::unique_ptr<tensorflow::Session>* session) { 100 tensorflow::GraphDef graph_def; 101 Status load_graph_status = 102 ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def); 103 if (!load_graph_status.ok()) { 104 return tensorflow::errors::NotFound("Failed to load compute graph at '", 105 graph_file_name, "'"); 106 } 107 session->reset(tensorflow::NewSession(tensorflow::SessionOptions())); 108 Status session_create_status = (*session)->Create(graph_def); 109 if (!session_create_status.ok()) { 110 return session_create_status; 111 } 112 return Status::OK(); 113 } 114 115 // Takes a file name, and loads a list of labels from it, one per line, and 116 // returns a vector of the strings. 
// returns a vector of the strings.
Status ReadLabelsFile(const string& file_name, std::vector<string>* result) {
  std::ifstream file(file_name);
  if (!file) {
    return tensorflow::errors::NotFound("Labels file '", file_name,
                                        "' not found.");
  }
  result->clear();
  string line;
  while (std::getline(file, line)) {
    // Lines are stored verbatim (no trimming), so a trailing '\r' from a
    // CRLF-encoded labels file would be kept in the label -- assumes the
    // file uses Unix line endings. TODO confirm.
    result->push_back(line);
  }
  return Status::OK();
}

}  // namespace

int main(int argc, char* argv[]) {
  // Command-line flag storage with defaults; see the file header comment for
  // an example invocation.
  string wav = "";
  string graph = "";
  string labels = "";
  string ground_truth = "";
  string input_data_name = "decoded_sample_data:0";
  string input_rate_name = "decoded_sample_data:1";
  string output_name = "labels_softmax";
  int32 clip_duration_ms = 1000;
  int32 clip_stride_ms = 30;
  int32 average_window_ms = 500;
  int32 time_tolerance_ms = 750;
  int32 suppression_ms = 1500;
  float detection_threshold = 0.7f;
  bool verbose = false;
  std::vector<Flag> flag_list = {
      Flag("wav", &wav, "audio file to be identified"),
      Flag("graph", &graph, "model to be executed"),
      Flag("labels", &labels, "path to file containing labels"),
      Flag("ground_truth", &ground_truth,
           "path to file containing correct times and labels of words in the "
           "audio as <word>,<timestamp in ms> lines"),
      Flag("input_data_name", &input_data_name,
           "name of input data node in model"),
      Flag("input_rate_name", &input_rate_name,
           "name of input sample rate node in model"),
      Flag("output_name", &output_name, "name of output node in model"),
      Flag("clip_duration_ms", &clip_duration_ms,
           "length of recognition window"),
      Flag("average_window_ms", &average_window_ms,
           "length of window to smooth results over"),
      Flag("time_tolerance_ms", &time_tolerance_ms,
           "maximum gap allowed between a recognition and ground truth"),
      Flag("suppression_ms", &suppression_ms,
           "how long to ignore others for after a recognition"),
      Flag("clip_stride_ms", &clip_stride_ms, "how often to run recognition"),
      Flag("detection_threshold", &detection_threshold,
           "what score is required to trigger detection of a word"),
      Flag("verbose", &verbose, "whether to log extra debugging information"),
  };
  string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
  if (!parse_result) {
    LOG(ERROR) << usage;
    return -1;
  }

  // We need to call this to set up global state for TensorFlow.
  tensorflow::port::InitMain(argv[0], &argc, &argv);
  if (argc > 1) {
    LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
    return -1;
  }

  // First we load and initialize the model.
  std::unique_ptr<tensorflow::Session> session;
  Status load_graph_status = LoadGraph(graph, &session);
  if (!load_graph_status.ok()) {
    LOG(ERROR) << load_graph_status;
    return -1;
  }

  // The labels file maps each model output index to a human-readable name.
  std::vector<string> labels_list;
  Status read_labels_status = ReadLabelsFile(labels, &labels_list);
  if (!read_labels_status.ok()) {
    LOG(ERROR) << read_labels_status;
    return -1;
  }

  // Ground truth arrives as (label, timestamp) pairs; timestamps are int64.
  std::vector<std::pair<string, tensorflow::int64>> ground_truth_list;
  Status read_ground_truth_status =
      tensorflow::ReadGroundTruthFile(ground_truth, &ground_truth_list);
  if (!read_ground_truth_status.ok()) {
    LOG(ERROR) << read_ground_truth_status;
    return -1;
  }

  // Read the entire .wav into memory, then decode it to mono float samples.
  string wav_string;
  Status read_wav_status = tensorflow::ReadFileToString(
      tensorflow::Env::Default(), wav, &wav_string);
  if (!read_wav_status.ok()) {
    LOG(ERROR) << read_wav_status;
    return -1;
  }
  std::vector<float> audio_data;
  uint32 sample_count;
  uint16 channel_count;
  uint32 sample_rate;
  Status decode_wav_status = tensorflow::wav::DecodeLin16WaveAsFloatVector(
      wav_string, &audio_data, &sample_count, &channel_count, &sample_rate);
  if (!decode_wav_status.ok()) {
    LOG(ERROR) << decode_wav_status;
    return -1;
  }
  if (channel_count != 1) {
    LOG(ERROR) << "Only mono .wav files can be used, but input has "
               << channel_count << " channels.";
    return -1;
  }

  // Convert the millisecond-based window settings into sample counts.
  const int64 clip_duration_samples = (clip_duration_ms * sample_rate) / 1000;
  const int64 clip_stride_samples = (clip_stride_ms * sample_rate) / 1000;
  // Reusable input tensor holding one window of mono samples, shaped
  // {clip_duration_samples, 1}; refilled on every loop iteration below.
  Tensor audio_data_tensor(tensorflow::DT_FLOAT,
                           tensorflow::TensorShape({clip_duration_samples, 1}));

  // Scalar tensor carrying the sample rate to the model's rate input.
  Tensor sample_rate_tensor(tensorflow::DT_INT32, tensorflow::TensorShape({}));
  sample_rate_tensor.scalar<int32>()() = sample_rate;

  // Smooths raw per-window scores over average_window_ms and applies the
  // detection threshold / suppression window to decide on new commands.
  tensorflow::RecognizeCommands recognize_commands(
      labels_list, average_window_ms, detection_threshold, suppression_ms);

  // Accumulates every (command, time_ms) detection for accuracy scoring.
  std::vector<std::pair<string, int64>> all_found_words;
  tensorflow::StreamingAccuracyStats previous_stats;

  // Slide a clip_duration_samples window over the audio, advancing by
  // clip_stride_samples each iteration; any partial window at the end of the
  // file is skipped (loop bound is sample_count - clip_duration_samples).
  const int64 audio_data_end = (sample_count - clip_duration_samples);
  for (int64 audio_data_offset = 0; audio_data_offset < audio_data_end;
       audio_data_offset += clip_stride_samples) {
    const float* input_start = &(audio_data[audio_data_offset]);
    const float* input_end = input_start + clip_duration_samples;
    std::copy(input_start, input_end, audio_data_tensor.flat<float>().data());

    // Actually run the audio through the model.
    std::vector<Tensor> outputs;
    Status run_status = session->Run({{input_data_name, audio_data_tensor},
                                      {input_rate_name, sample_rate_tensor}},
                                     {output_name}, {}, &outputs);
    if (!run_status.ok()) {
      LOG(ERROR) << "Running model failed: " << run_status;
      return -1;
    }

    // Stream position of the start of this window, in milliseconds.
    const int64 current_time_ms = (audio_data_offset * 1000) / sample_rate;
    string found_command;
    float score;
    bool is_new_command;
    Status recognize_status = recognize_commands.ProcessLatestResults(
        outputs[0], current_time_ms, &found_command, &score, &is_new_command);
    if (!recognize_status.ok()) {
      LOG(ERROR) << "Recognition processing failed: " << recognize_status;
      return -1;
    }

    if (is_new_command && (found_command != "_silence_")) {
      all_found_words.push_back({found_command, current_time_ms});
      if (verbose) {
        // Recompute cumulative stats after adding this detection, then diff
        // against the previous iteration's stats: exactly one counter should
        // have advanced by one, which classifies this single detection.
        tensorflow::StreamingAccuracyStats stats;
        tensorflow::CalculateAccuracyStats(ground_truth_list, all_found_words,
                                           current_time_ms, time_tolerance_ms,
                                           &stats);
        int32 false_positive_delta = stats.how_many_false_positives -
                                     previous_stats.how_many_false_positives;
        int32 correct_delta = stats.how_many_correct_words -
                              previous_stats.how_many_correct_words;
        int32 wrong_delta =
            stats.how_many_wrong_words - previous_stats.how_many_wrong_words;
        string recognition_state;
        if (false_positive_delta == 1) {
          recognition_state = " (False Positive)";
        } else if (correct_delta == 1) {
          recognition_state = " (Correct)";
        } else if (wrong_delta == 1) {
          recognition_state = " (Wrong)";
        } else {
          // No single counter advanced by exactly one; log and fall through
          // with an empty state string.
          LOG(ERROR) << "Unexpected state in statistics";
        }
        LOG(INFO) << current_time_ms << "ms: " << found_command << ": " << score
                  << recognition_state;
        previous_stats = stats;
        tensorflow::PrintAccuracyStats(stats);
      }
    }
  }

  // Final scoring over the complete stream. NOTE(review): -1 presumably means
  // "no time cutoff" here, mirroring the per-detection calls above that pass
  // current_time_ms -- verify against CalculateAccuracyStats.
  tensorflow::StreamingAccuracyStats stats;
  tensorflow::CalculateAccuracyStats(ground_truth_list, all_found_words, -1,
                                     time_tolerance_ms, &stats);
  tensorflow::PrintAccuracyStats(stats);

  return 0;
}