Home | History | Annotate | Download | only in demo
      1 /*
      2  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *       http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package org.tensorflow.demo;
     18 
     19 import android.util.Log;
     20 import android.util.Pair;
     21 import java.util.ArrayDeque;
     22 import java.util.ArrayList;
     23 import java.util.Arrays;
     24 import java.util.Deque;
     25 import java.util.List;
     26 
     27 /** Reads in results from an instantaneous audio recognition model and smoothes them over time. */
     28 public class RecognizeCommands {
     29   // Configuration settings.
     30   private List<String> labels = new ArrayList<String>();
     31   private long averageWindowDurationMs;
     32   private float detectionThreshold;
     33   private int suppressionMs;
     34   private int minimumCount;
     35   private long minimumTimeBetweenSamplesMs;
     36 
     37   // Working variables.
     38   private Deque<Pair<Long, float[]>> previousResults = new ArrayDeque<Pair<Long, float[]>>();
     39   private String previousTopLabel;
     40   private int labelsCount;
     41   private long previousTopLabelTime;
     42   private float previousTopLabelScore;
     43 
     44   private static final String SILENCE_LABEL = "_silence_";
     45   private static final long MINIMUM_TIME_FRACTION = 4;
     46 
     47   public RecognizeCommands(
     48       List<String> inLabels,
     49       long inAverageWindowDurationMs,
     50       float inDetectionThreshold,
     51       int inSuppressionMS,
     52       int inMinimumCount,
     53       long inMinimumTimeBetweenSamplesMS) {
     54     labels = inLabels;
     55     averageWindowDurationMs = inAverageWindowDurationMs;
     56     detectionThreshold = inDetectionThreshold;
     57     suppressionMs = inSuppressionMS;
     58     minimumCount = inMinimumCount;
     59     labelsCount = inLabels.size();
     60     previousTopLabel = SILENCE_LABEL;
     61     previousTopLabelTime = Long.MIN_VALUE;
     62     previousTopLabelScore = 0.0f;
     63     minimumTimeBetweenSamplesMs = inMinimumTimeBetweenSamplesMS;
     64   }
     65 
     66   /** Holds information about what's been recognized. */
     67   public static class RecognitionResult {
     68     public final String foundCommand;
     69     public final float score;
     70     public final boolean isNewCommand;
     71 
     72     public RecognitionResult(String inFoundCommand, float inScore, boolean inIsNewCommand) {
     73       foundCommand = inFoundCommand;
     74       score = inScore;
     75       isNewCommand = inIsNewCommand;
     76     }
     77   }
     78 
     79   private static class ScoreForSorting implements Comparable<ScoreForSorting> {
     80     public final float score;
     81     public final int index;
     82 
     83     public ScoreForSorting(float inScore, int inIndex) {
     84       score = inScore;
     85       index = inIndex;
     86     }
     87 
     88     @Override
     89     public int compareTo(ScoreForSorting other) {
     90       if (this.score > other.score) {
     91         return -1;
     92       } else if (this.score < other.score) {
     93         return 1;
     94       } else {
     95         return 0;
     96       }
     97     }
     98   }
     99 
    100   public RecognitionResult processLatestResults(float[] currentResults, long currentTimeMS) {
    101     if (currentResults.length != labelsCount) {
    102       throw new RuntimeException(
    103           "The results for recognition should contain "
    104               + labelsCount
    105               + " elements, but there are "
    106               + currentResults.length);
    107     }
    108 
    109     if ((!previousResults.isEmpty()) && (currentTimeMS < previousResults.getFirst().first)) {
    110       throw new RuntimeException(
    111           "You must feed results in increasing time order, but received a timestamp of "
    112               + currentTimeMS
    113               + " that was earlier than the previous one of "
    114               + previousResults.getFirst().first);
    115     }
    116 
    117     final int howManyResults = previousResults.size();
    118     // Ignore any results that are coming in too frequently.
    119     if (howManyResults > 1) {
    120       final long timeSinceMostRecent = currentTimeMS - previousResults.getLast().first;
    121       if (timeSinceMostRecent < minimumTimeBetweenSamplesMs) {
    122         return new RecognitionResult(previousTopLabel, previousTopLabelScore, false);
    123       }
    124     }
    125 
    126     // Add the latest results to the head of the queue.
    127     previousResults.addLast(new Pair<Long, float[]>(currentTimeMS, currentResults));
    128 
    129     // Prune any earlier results that are too old for the averaging window.
    130     final long timeLimit = currentTimeMS - averageWindowDurationMs;
    131     while (previousResults.getFirst().first < timeLimit) {
    132       previousResults.removeFirst();
    133     }
    134 
    135     // If there are too few results, assume the result will be unreliable and
    136     // bail.
    137     final long earliestTime = previousResults.getFirst().first;
    138     final long samplesDuration = currentTimeMS - earliestTime;
    139     if ((howManyResults < minimumCount)
    140         || (samplesDuration < (averageWindowDurationMs / MINIMUM_TIME_FRACTION))) {
    141       Log.v("RecognizeResult", "Too few results");
    142       return new RecognitionResult(previousTopLabel, 0.0f, false);
    143     }
    144 
    145     // Calculate the average score across all the results in the window.
    146     float[] averageScores = new float[labelsCount];
    147     for (Pair<Long, float[]> previousResult : previousResults) {
    148       final float[] scoresTensor = previousResult.second;
    149       int i = 0;
    150       while (i < scoresTensor.length) {
    151         averageScores[i] += scoresTensor[i] / howManyResults;
    152         ++i;
    153       }
    154     }
    155 
    156     // Sort the averaged results in descending score order.
    157     ScoreForSorting[] sortedAverageScores = new ScoreForSorting[labelsCount];
    158     for (int i = 0; i < labelsCount; ++i) {
    159       sortedAverageScores[i] = new ScoreForSorting(averageScores[i], i);
    160     }
    161     Arrays.sort(sortedAverageScores);
    162 
    163     // See if the latest top score is enough to trigger a detection.
    164     final int currentTopIndex = sortedAverageScores[0].index;
    165     final String currentTopLabel = labels.get(currentTopIndex);
    166     final float currentTopScore = sortedAverageScores[0].score;
    167     // If we've recently had another label trigger, assume one that occurs too
    168     // soon afterwards is a bad result.
    169     long timeSinceLastTop;
    170     if (previousTopLabel.equals(SILENCE_LABEL) || (previousTopLabelTime == Long.MIN_VALUE)) {
    171       timeSinceLastTop = Long.MAX_VALUE;
    172     } else {
    173       timeSinceLastTop = currentTimeMS - previousTopLabelTime;
    174     }
    175     boolean isNewCommand;
    176     if ((currentTopScore > detectionThreshold) && (timeSinceLastTop > suppressionMs)) {
    177       previousTopLabel = currentTopLabel;
    178       previousTopLabelTime = currentTimeMS;
    179       previousTopLabelScore = currentTopScore;
    180       isNewCommand = true;
    181     } else {
    182       isNewCommand = false;
    183     }
    184     return new RecognitionResult(currentTopLabel, currentTopScore, isNewCommand);
    185   }
    186 }
    187