1 /* 2 * Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package org.tensorflow.demo; 18 19 import android.util.Log; 20 import android.util.Pair; 21 import java.util.ArrayDeque; 22 import java.util.ArrayList; 23 import java.util.Arrays; 24 import java.util.Deque; 25 import java.util.List; 26 27 /** Reads in results from an instantaneous audio recognition model and smoothes them over time. */ 28 public class RecognizeCommands { 29 // Configuration settings. 30 private List<String> labels = new ArrayList<String>(); 31 private long averageWindowDurationMs; 32 private float detectionThreshold; 33 private int suppressionMs; 34 private int minimumCount; 35 private long minimumTimeBetweenSamplesMs; 36 37 // Working variables. 38 private Deque<Pair<Long, float[]>> previousResults = new ArrayDeque<Pair<Long, float[]>>(); 39 private String previousTopLabel; 40 private int labelsCount; 41 private long previousTopLabelTime; 42 private float previousTopLabelScore; 43 44 private static final String SILENCE_LABEL = "_silence_"; 45 private static final long MINIMUM_TIME_FRACTION = 4; 46 47 public RecognizeCommands( 48 List<String> inLabels, 49 long inAverageWindowDurationMs, 50 float inDetectionThreshold, 51 int inSuppressionMS, 52 int inMinimumCount, 53 long inMinimumTimeBetweenSamplesMS) { 54 labels = inLabels; 55 averageWindowDurationMs = inAverageWindowDurationMs; 56 detectionThreshold = inDetectionThreshold; 57 suppressionMs = inSuppressionMS; 58 minimumCount = inMinimumCount; 59 labelsCount = inLabels.size(); 60 previousTopLabel = SILENCE_LABEL; 61 previousTopLabelTime = Long.MIN_VALUE; 62 previousTopLabelScore = 0.0f; 63 minimumTimeBetweenSamplesMs = inMinimumTimeBetweenSamplesMS; 64 } 65 66 /** Holds information about what's been recognized. */ 67 public static class RecognitionResult { 68 public final String foundCommand; 69 public final float score; 70 public final boolean isNewCommand; 71 72 public RecognitionResult(String inFoundCommand, float inScore, boolean inIsNewCommand) { 73 foundCommand = inFoundCommand; 74 score = inScore; 75 isNewCommand = inIsNewCommand; 76 } 77 } 78 79 private static class ScoreForSorting implements Comparable<ScoreForSorting> { 80 public final float score; 81 public final int index; 82 83 public ScoreForSorting(float inScore, int inIndex) { 84 score = inScore; 85 index = inIndex; 86 } 87 88 @Override 89 public int compareTo(ScoreForSorting other) { 90 if (this.score > other.score) { 91 return -1; 92 } else if (this.score < other.score) { 93 return 1; 94 } else { 95 return 0; 96 } 97 } 98 } 99 100 public RecognitionResult processLatestResults(float[] currentResults, long currentTimeMS) { 101 if (currentResults.length != labelsCount) { 102 throw new RuntimeException( 103 "The results for recognition should contain " 104 + labelsCount 105 + " elements, but there are " 106 + currentResults.length); 107 } 108 109 if ((!previousResults.isEmpty()) && (currentTimeMS < previousResults.getFirst().first)) { 110 throw new RuntimeException( 111 "You must feed results in increasing time order, but received a timestamp of " 112 + currentTimeMS 113 + " that was earlier than the previous one of " 114 + previousResults.getFirst().first); 115 } 116 117 final int howManyResults = previousResults.size(); 118 // Ignore any results that are coming in too frequently. 119 if (howManyResults > 1) { 120 final long timeSinceMostRecent = currentTimeMS - previousResults.getLast().first; 121 if (timeSinceMostRecent < minimumTimeBetweenSamplesMs) { 122 return new RecognitionResult(previousTopLabel, previousTopLabelScore, false); 123 } 124 } 125 126 // Add the latest results to the head of the queue. 127 previousResults.addLast(new Pair<Long, float[]>(currentTimeMS, currentResults)); 128 129 // Prune any earlier results that are too old for the averaging window. 130 final long timeLimit = currentTimeMS - averageWindowDurationMs; 131 while (previousResults.getFirst().first < timeLimit) { 132 previousResults.removeFirst(); 133 } 134 135 // If there are too few results, assume the result will be unreliable and 136 // bail. 137 final long earliestTime = previousResults.getFirst().first; 138 final long samplesDuration = currentTimeMS - earliestTime; 139 if ((howManyResults < minimumCount) 140 || (samplesDuration < (averageWindowDurationMs / MINIMUM_TIME_FRACTION))) { 141 Log.v("RecognizeResult", "Too few results"); 142 return new RecognitionResult(previousTopLabel, 0.0f, false); 143 } 144 145 // Calculate the average score across all the results in the window. 146 float[] averageScores = new float[labelsCount]; 147 for (Pair<Long, float[]> previousResult : previousResults) { 148 final float[] scoresTensor = previousResult.second; 149 int i = 0; 150 while (i < scoresTensor.length) { 151 averageScores[i] += scoresTensor[i] / howManyResults; 152 ++i; 153 } 154 } 155 156 // Sort the averaged results in descending score order. 157 ScoreForSorting[] sortedAverageScores = new ScoreForSorting[labelsCount]; 158 for (int i = 0; i < labelsCount; ++i) { 159 sortedAverageScores[i] = new ScoreForSorting(averageScores[i], i); 160 } 161 Arrays.sort(sortedAverageScores); 162 163 // See if the latest top score is enough to trigger a detection. 164 final int currentTopIndex = sortedAverageScores[0].index; 165 final String currentTopLabel = labels.get(currentTopIndex); 166 final float currentTopScore = sortedAverageScores[0].score; 167 // If we've recently had another label trigger, assume one that occurs too 168 // soon afterwards is a bad result. 169 long timeSinceLastTop; 170 if (previousTopLabel.equals(SILENCE_LABEL) || (previousTopLabelTime == Long.MIN_VALUE)) { 171 timeSinceLastTop = Long.MAX_VALUE; 172 } else { 173 timeSinceLastTop = currentTimeMS - previousTopLabelTime; 174 } 175 boolean isNewCommand; 176 if ((currentTopScore > detectionThreshold) && (timeSinceLastTop > suppressionMs)) { 177 previousTopLabel = currentTopLabel; 178 previousTopLabelTime = currentTimeMS; 179 previousTopLabelScore = currentTopScore; 180 isNewCommand = true; 181 } else { 182 isNewCommand = false; 183 } 184 return new RecognitionResult(currentTopLabel, currentTopScore, isNewCommand); 185 } 186 } 187