Home | History | Annotate | Download | only in speech_commands
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
#include "tensorflow/examples/speech_commands/accuracy_utils.h"

#include <algorithm>
#include <fstream>
#include <iomanip>
#include <limits>
#include <unordered_set>

#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
     24 
     25 namespace tensorflow {
     26 
     27 Status ReadGroundTruthFile(const string& file_name,
     28                            std::vector<std::pair<string, int64>>* result) {
     29   std::ifstream file(file_name);
     30   if (!file) {
     31     return tensorflow::errors::NotFound("Ground truth file '", file_name,
     32                                         "' not found.");
     33   }
     34   result->clear();
     35   string line;
     36   while (std::getline(file, line)) {
     37     std::vector<string> pieces = tensorflow::str_util::Split(line, ',');
     38     if (pieces.size() != 2) {
     39       continue;
     40     }
     41     float timestamp;
     42     if (!tensorflow::strings::safe_strtof(pieces[1].c_str(), &timestamp)) {
     43       return tensorflow::errors::InvalidArgument(
     44           "Wrong number format at line: ", line);
     45     }
     46     string label = pieces[0];
     47     auto timestamp_int64 = static_cast<int64>(timestamp);
     48     result->push_back({label, timestamp_int64});
     49   }
     50   std::sort(result->begin(), result->end(),
     51             [](const std::pair<string, int64>& left,
     52                const std::pair<string, int64>& right) {
     53               return left.second < right.second;
     54             });
     55   return Status::OK();
     56 }
     57 
     58 void CalculateAccuracyStats(
     59     const std::vector<std::pair<string, int64>>& ground_truth_list,
     60     const std::vector<std::pair<string, int64>>& found_words,
     61     int64 up_to_time_ms, int64 time_tolerance_ms,
     62     StreamingAccuracyStats* stats) {
     63   int64 latest_possible_time;
     64   if (up_to_time_ms == -1) {
     65     latest_possible_time = std::numeric_limits<int64>::max();
     66   } else {
     67     latest_possible_time = up_to_time_ms + time_tolerance_ms;
     68   }
     69   stats->how_many_ground_truth_words = 0;
     70   for (const std::pair<string, int64>& ground_truth : ground_truth_list) {
     71     const int64 ground_truth_time = ground_truth.second;
     72     if (ground_truth_time > latest_possible_time) {
     73       break;
     74     }
     75     ++stats->how_many_ground_truth_words;
     76   }
     77 
     78   stats->how_many_false_positives = 0;
     79   stats->how_many_correct_words = 0;
     80   stats->how_many_wrong_words = 0;
     81   std::unordered_set<int64> has_ground_truth_been_matched;
     82   for (const std::pair<string, int64>& found_word : found_words) {
     83     const string& found_label = found_word.first;
     84     const int64 found_time = found_word.second;
     85     const int64 earliest_time = found_time - time_tolerance_ms;
     86     const int64 latest_time = found_time + time_tolerance_ms;
     87     bool has_match_been_found = false;
     88     for (const std::pair<string, int64>& ground_truth : ground_truth_list) {
     89       const int64 ground_truth_time = ground_truth.second;
     90       if ((ground_truth_time > latest_time) ||
     91           (ground_truth_time > latest_possible_time)) {
     92         break;
     93       }
     94       if (ground_truth_time < earliest_time) {
     95         continue;
     96       }
     97       const string& ground_truth_label = ground_truth.first;
     98       if ((ground_truth_label == found_label) &&
     99           (has_ground_truth_been_matched.count(ground_truth_time) == 0)) {
    100         ++stats->how_many_correct_words;
    101       } else {
    102         ++stats->how_many_wrong_words;
    103       }
    104       has_ground_truth_been_matched.insert(ground_truth_time);
    105       has_match_been_found = true;
    106       break;
    107     }
    108     if (!has_match_been_found) {
    109       ++stats->how_many_false_positives;
    110     }
    111   }
    112   stats->how_many_ground_truth_matched = has_ground_truth_been_matched.size();
    113 }
    114 
    115 void PrintAccuracyStats(const StreamingAccuracyStats& stats) {
    116   if (stats.how_many_ground_truth_words == 0) {
    117     LOG(INFO) << "No ground truth yet, " << stats.how_many_false_positives
    118               << " false positives";
    119   } else {
    120     float any_match_percentage =
    121         (stats.how_many_ground_truth_matched * 100.0f) /
    122         stats.how_many_ground_truth_words;
    123     float correct_match_percentage = (stats.how_many_correct_words * 100.0f) /
    124                                      stats.how_many_ground_truth_words;
    125     float wrong_match_percentage = (stats.how_many_wrong_words * 100.0f) /
    126                                    stats.how_many_ground_truth_words;
    127     float false_positive_percentage =
    128         (stats.how_many_false_positives * 100.0f) /
    129         stats.how_many_ground_truth_words;
    130 
    131     LOG(INFO) << std::setprecision(1) << std::fixed << any_match_percentage
    132               << "% matched, " << correct_match_percentage << "% correctly, "
    133               << wrong_match_percentage << "% wrongly, "
    134               << false_positive_percentage << "% false positives ";
    135   }
    136 }
    137 
    138 }  // namespace tensorflow
    139