Home | History | Annotate | Download | only in speech_commands
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 /*
     16 
     17 Tool to create accuracy statistics from running an audio recognition model on a
     18 continuous stream of samples.
     19 
     20 This is designed to be an environment for running experiments on new models and
     21 settings to understand the effects they will have in a real application. You
     22 need to supply it with a long audio file containing sounds you want to recognize
     23 and a text file listing the labels of each sound along with the time they occur.
     24 With this information, and a frozen model, the tool will process the audio
     25 stream, apply the model, and keep track of how many mistakes and successes the
     26 model achieved.
     27 
     28 The matched percentage is the number of sounds that were correctly classified,
     29 as a percentage of the total number of sounds listed in the ground truth file.
     30 A correct classification is when the right label is chosen within a short time
     31 of the expected ground truth, where the time tolerance is controlled by the
     32 'time_tolerance_ms' command line flag.
     33 
     34 The wrong percentage is how many sounds triggered a detection (the classifier
     35 figured out it wasn't silence or background noise), but the detected class was
     36 wrong. This is also a percentage of the total number of ground truth sounds.
     37 
     38 The false positive percentage is how many sounds were detected when there was
     39 only silence or background noise. This is also expressed as a percentage of the
     40 total number of ground truth sounds, though since it can be large it may go
     41 above 100%.
     42 
     43 The easiest way to get an audio file and labels to test with is by using the
     44 'generate_streaming_test_wav' script. This will synthesize a test file with
     45 randomly placed sounds and background noise, and output a text file with the
     46 ground truth.
     47 
If you want to test on natural data, you need to use a .wav with the same sample
rate as your model (often 16,000 samples per second), and note down where the
sounds occur in time. Save this information out as a comma-separated text file,
where the first column is the label and the second is the time in milliseconds
from the start of the file at which the sound occurs.
     53 
     54 Here's an example of how to run the tool:
     55 
     56 bazel run tensorflow/examples/speech_commands:test_streaming_accuracy -- \
     57 --wav=/tmp/streaming_test_bg.wav \
     58 --graph=/tmp/conv_frozen.pb \
     59 --labels=/tmp/speech_commands_train/conv_labels.txt \
     60 --ground_truth=/tmp/streaming_test_labels.txt --verbose \
     61 --clip_duration_ms=1000 --detection_threshold=0.70 --average_window_ms=500 \
     62 --suppression_ms=500 --time_tolerance_ms=1500
     63 
     64  */
     65 
     66 #include <fstream>
     67 #include <iomanip>
     68 #include <unordered_set>
     69 #include <vector>
     70 
     71 #include "tensorflow/core/framework/tensor.h"
     72 #include "tensorflow/core/lib/io/path.h"
     73 #include "tensorflow/core/lib/strings/numbers.h"
     74 #include "tensorflow/core/lib/strings/str_util.h"
     75 #include "tensorflow/core/lib/wav/wav_io.h"
     76 #include "tensorflow/core/platform/init_main.h"
     77 #include "tensorflow/core/platform/logging.h"
     78 #include "tensorflow/core/platform/types.h"
     79 #include "tensorflow/core/public/session.h"
     80 #include "tensorflow/core/util/command_line_flags.h"
     81 #include "tensorflow/examples/speech_commands/accuracy_utils.h"
     82 #include "tensorflow/examples/speech_commands/recognize_commands.h"
     83 
     84 // These are all common classes it's handy to reference with no namespace.
     85 using tensorflow::Flag;
     86 using tensorflow::Status;
     87 using tensorflow::Tensor;
     88 using tensorflow::int32;
     89 using tensorflow::int64;
     90 using tensorflow::string;
     91 using tensorflow::uint16;
     92 using tensorflow::uint32;
     93 
     94 namespace {
     95 
     96 // Reads a model graph definition from disk, and creates a session object you
     97 // can use to run it.
     98 Status LoadGraph(const string& graph_file_name,
     99                  std::unique_ptr<tensorflow::Session>* session) {
    100   tensorflow::GraphDef graph_def;
    101   Status load_graph_status =
    102       ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
    103   if (!load_graph_status.ok()) {
    104     return tensorflow::errors::NotFound("Failed to load compute graph at '",
    105                                         graph_file_name, "'");
    106   }
    107   session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
    108   Status session_create_status = (*session)->Create(graph_def);
    109   if (!session_create_status.ok()) {
    110     return session_create_status;
    111   }
    112   return Status::OK();
    113 }
    114 
    115 // Takes a file name, and loads a list of labels from it, one per line, and
    116 // returns a vector of the strings.
    117 Status ReadLabelsFile(const string& file_name, std::vector<string>* result) {
    118   std::ifstream file(file_name);
    119   if (!file) {
    120     return tensorflow::errors::NotFound("Labels file '", file_name,
    121                                         "' not found.");
    122   }
    123   result->clear();
    124   string line;
    125   while (std::getline(file, line)) {
    126     result->push_back(line);
    127   }
    128   return Status::OK();
    129 }
    130 
    131 }  // namespace
    132 
    133 int main(int argc, char* argv[]) {
    134   string wav = "";
    135   string graph = "";
    136   string labels = "";
    137   string ground_truth = "";
    138   string input_data_name = "decoded_sample_data:0";
    139   string input_rate_name = "decoded_sample_data:1";
    140   string output_name = "labels_softmax";
    141   int32 clip_duration_ms = 1000;
    142   int32 clip_stride_ms = 30;
    143   int32 average_window_ms = 500;
    144   int32 time_tolerance_ms = 750;
    145   int32 suppression_ms = 1500;
    146   float detection_threshold = 0.7f;
    147   bool verbose = false;
    148   std::vector<Flag> flag_list = {
    149       Flag("wav", &wav, "audio file to be identified"),
    150       Flag("graph", &graph, "model to be executed"),
    151       Flag("labels", &labels, "path to file containing labels"),
    152       Flag("ground_truth", &ground_truth,
    153            "path to file containing correct times and labels of words in the "
    154            "audio as <word>,<timestamp in ms> lines"),
    155       Flag("input_data_name", &input_data_name,
    156            "name of input data node in model"),
    157       Flag("input_rate_name", &input_rate_name,
    158            "name of input sample rate node in model"),
    159       Flag("output_name", &output_name, "name of output node in model"),
    160       Flag("clip_duration_ms", &clip_duration_ms,
    161            "length of recognition window"),
    162       Flag("average_window_ms", &average_window_ms,
    163            "length of window to smooth results over"),
    164       Flag("time_tolerance_ms", &time_tolerance_ms,
    165            "maximum gap allowed between a recognition and ground truth"),
    166       Flag("suppression_ms", &suppression_ms,
    167            "how long to ignore others for after a recognition"),
    168       Flag("clip_stride_ms", &clip_stride_ms, "how often to run recognition"),
    169       Flag("detection_threshold", &detection_threshold,
    170            "what score is required to trigger detection of a word"),
    171       Flag("verbose", &verbose, "whether to log extra debugging information"),
    172   };
    173   string usage = tensorflow::Flags::Usage(argv[0], flag_list);
    174   const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
    175   if (!parse_result) {
    176     LOG(ERROR) << usage;
    177     return -1;
    178   }
    179 
    180   // We need to call this to set up global state for TensorFlow.
    181   tensorflow::port::InitMain(argv[0], &argc, &argv);
    182   if (argc > 1) {
    183     LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
    184     return -1;
    185   }
    186 
    187   // First we load and initialize the model.
    188   std::unique_ptr<tensorflow::Session> session;
    189   Status load_graph_status = LoadGraph(graph, &session);
    190   if (!load_graph_status.ok()) {
    191     LOG(ERROR) << load_graph_status;
    192     return -1;
    193   }
    194 
    195   std::vector<string> labels_list;
    196   Status read_labels_status = ReadLabelsFile(labels, &labels_list);
    197   if (!read_labels_status.ok()) {
    198     LOG(ERROR) << read_labels_status;
    199     return -1;
    200   }
    201 
    202   std::vector<std::pair<string, tensorflow::int64>> ground_truth_list;
    203   Status read_ground_truth_status =
    204       tensorflow::ReadGroundTruthFile(ground_truth, &ground_truth_list);
    205   if (!read_ground_truth_status.ok()) {
    206     LOG(ERROR) << read_ground_truth_status;
    207     return -1;
    208   }
    209 
    210   string wav_string;
    211   Status read_wav_status = tensorflow::ReadFileToString(
    212       tensorflow::Env::Default(), wav, &wav_string);
    213   if (!read_wav_status.ok()) {
    214     LOG(ERROR) << read_wav_status;
    215     return -1;
    216   }
    217   std::vector<float> audio_data;
    218   uint32 sample_count;
    219   uint16 channel_count;
    220   uint32 sample_rate;
    221   Status decode_wav_status = tensorflow::wav::DecodeLin16WaveAsFloatVector(
    222       wav_string, &audio_data, &sample_count, &channel_count, &sample_rate);
    223   if (!decode_wav_status.ok()) {
    224     LOG(ERROR) << decode_wav_status;
    225     return -1;
    226   }
    227   if (channel_count != 1) {
    228     LOG(ERROR) << "Only mono .wav files can be used, but input has "
    229                << channel_count << " channels.";
    230     return -1;
    231   }
    232 
    233   const int64 clip_duration_samples = (clip_duration_ms * sample_rate) / 1000;
    234   const int64 clip_stride_samples = (clip_stride_ms * sample_rate) / 1000;
    235   Tensor audio_data_tensor(tensorflow::DT_FLOAT,
    236                            tensorflow::TensorShape({clip_duration_samples, 1}));
    237 
    238   Tensor sample_rate_tensor(tensorflow::DT_INT32, tensorflow::TensorShape({}));
    239   sample_rate_tensor.scalar<int32>()() = sample_rate;
    240 
    241   tensorflow::RecognizeCommands recognize_commands(
    242       labels_list, average_window_ms, detection_threshold, suppression_ms);
    243 
    244   std::vector<std::pair<string, int64>> all_found_words;
    245   tensorflow::StreamingAccuracyStats previous_stats;
    246 
    247   const int64 audio_data_end = (sample_count - clip_duration_samples);
    248   for (int64 audio_data_offset = 0; audio_data_offset < audio_data_end;
    249        audio_data_offset += clip_stride_samples) {
    250     const float* input_start = &(audio_data[audio_data_offset]);
    251     const float* input_end = input_start + clip_duration_samples;
    252     std::copy(input_start, input_end, audio_data_tensor.flat<float>().data());
    253 
    254     // Actually run the audio through the model.
    255     std::vector<Tensor> outputs;
    256     Status run_status = session->Run({{input_data_name, audio_data_tensor},
    257                                       {input_rate_name, sample_rate_tensor}},
    258                                      {output_name}, {}, &outputs);
    259     if (!run_status.ok()) {
    260       LOG(ERROR) << "Running model failed: " << run_status;
    261       return -1;
    262     }
    263 
    264     const int64 current_time_ms = (audio_data_offset * 1000) / sample_rate;
    265     string found_command;
    266     float score;
    267     bool is_new_command;
    268     Status recognize_status = recognize_commands.ProcessLatestResults(
    269         outputs[0], current_time_ms, &found_command, &score, &is_new_command);
    270     if (!recognize_status.ok()) {
    271       LOG(ERROR) << "Recognition processing failed: " << recognize_status;
    272       return -1;
    273     }
    274 
    275     if (is_new_command && (found_command != "_silence_")) {
    276       all_found_words.push_back({found_command, current_time_ms});
    277       if (verbose) {
    278         tensorflow::StreamingAccuracyStats stats;
    279         tensorflow::CalculateAccuracyStats(ground_truth_list, all_found_words,
    280                                            current_time_ms, time_tolerance_ms,
    281                                            &stats);
    282         int32 false_positive_delta = stats.how_many_false_positives -
    283                                      previous_stats.how_many_false_positives;
    284         int32 correct_delta = stats.how_many_correct_words -
    285                               previous_stats.how_many_correct_words;
    286         int32 wrong_delta =
    287             stats.how_many_wrong_words - previous_stats.how_many_wrong_words;
    288         string recognition_state;
    289         if (false_positive_delta == 1) {
    290           recognition_state = " (False Positive)";
    291         } else if (correct_delta == 1) {
    292           recognition_state = " (Correct)";
    293         } else if (wrong_delta == 1) {
    294           recognition_state = " (Wrong)";
    295         } else {
    296           LOG(ERROR) << "Unexpected state in statistics";
    297         }
    298         LOG(INFO) << current_time_ms << "ms: " << found_command << ": " << score
    299                   << recognition_state;
    300         previous_stats = stats;
    301         tensorflow::PrintAccuracyStats(stats);
    302       }
    303     }
    304   }
    305 
    306   tensorflow::StreamingAccuracyStats stats;
    307   tensorflow::CalculateAccuracyStats(ground_truth_list, all_found_words, -1,
    308                                      time_tolerance_ms, &stats);
    309   tensorflow::PrintAccuracyStats(stats);
    310 
    311   return 0;
    312 }
    313