1 /* 2 * Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 /* Demonstrates how to run an audio recognition model in Android. 18 19 This example loads a simple speech recognition model trained by the tutorial at 20 https://www.tensorflow.org/tutorials/audio_training 21 22 The model files should be downloaded automatically from the TensorFlow website, 23 but if you have a custom model you can update the LABEL_FILENAME and 24 MODEL_FILENAME constants to point to your own files. 25 26 The example application displays a list view with all of the known audio labels, 27 and highlights each one when it thinks it has detected one through the 28 microphone. The averaging of results to give a more reliable signal happens in 29 the RecognizeCommands helper class. 
*/

package org.tensorflow.demo;

import android.animation.AnimatorInflater;
import android.animation.AnimatorSet;
import android.app.Activity;
import android.content.pm.PackageManager;
import android.media.AudioFormat;
import android.media.AudioRecord;
import android.media.MediaRecorder;
import android.os.Build;
import android.os.Bundle;
import android.util.Log;
import android.view.View;
import android.widget.ArrayAdapter;
import android.widget.Button;
import android.widget.ListView;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.locks.ReentrantLock;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
import org.tensorflow.demo.R;

/**
 * An activity that listens for audio and then uses a TensorFlow model to detect particular classes,
 * by default a small set of action words.
 *
 * <p>Two background threads do the work: a recording thread copies microphone samples into a
 * round-robin buffer, and a recognition thread periodically snapshots that buffer, runs the model
 * on it and highlights newly recognized commands in the list view.
 */
public class SpeechActivity extends Activity {

  // Constants that control the behavior of the recognition code and model
  // settings. See the audio recognition tutorial for a detailed explanation of
  // all these, but you should customize them to match your training settings if
  // you are running your own model.
  private static final int SAMPLE_RATE = 16000;
  private static final int SAMPLE_DURATION_MS = 1000;
  private static final int RECORDING_LENGTH = (int) (SAMPLE_RATE * SAMPLE_DURATION_MS / 1000);
  private static final long AVERAGE_WINDOW_DURATION_MS = 500;
  private static final float DETECTION_THRESHOLD = 0.70f;
  private static final int SUPPRESSION_MS = 1500;
  private static final int MINIMUM_COUNT = 3;
  private static final long MINIMUM_TIME_BETWEEN_SAMPLES_MS = 30;
  private static final String LABEL_FILENAME = "file:///android_asset/conv_actions_labels.txt";
  private static final String MODEL_FILENAME = "file:///android_asset/conv_actions_frozen.pb";
  private static final String INPUT_DATA_NAME = "decoded_sample_data:0";
  private static final String SAMPLE_RATE_NAME = "decoded_sample_data:1";
  private static final String OUTPUT_SCORES_NAME = "labels_softmax";

  // UI elements.
  private static final int REQUEST_RECORD_AUDIO = 13;
  private Button quitButton;
  private ListView labelsListView;
  private static final String LOG_TAG = SpeechActivity.class.getSimpleName();

  // Working variables.
  // recordingBuffer and recordingOffset are shared between the two worker threads and are only
  // touched while holding recordingBufferLock.
  short[] recordingBuffer = new short[RECORDING_LENGTH];
  int recordingOffset = 0;
  // The stop flags are written on the UI thread and polled by the worker loops without any lock,
  // so they must be volatile for a stop request to become visible to those threads.
  volatile boolean shouldContinue = true;
  private Thread recordingThread;
  volatile boolean shouldContinueRecognition = true;
  private Thread recognitionThread;
  private final ReentrantLock recordingBufferLock = new ReentrantLock();
  private TensorFlowInferenceInterface inferenceInterface;
  private List<String> labels = new ArrayList<String>();
  private List<String> displayedLabels = new ArrayList<>();
  private RecognizeCommands recognizeCommands = null;

  @Override
  protected void onCreate(Bundle savedInstanceState) {
    // Set up the UI.
    super.onCreate(savedInstanceState);
    setContentView(R.layout.activity_speech);
    quitButton = (Button) findViewById(R.id.quit);
    quitButton.setOnClickListener(
        new View.OnClickListener() {
          @Override
          public void onClick(View view) {
            moveTaskToBack(true);
            android.os.Process.killProcess(android.os.Process.myPid());
            System.exit(1);
          }
        });
    labelsListView = (ListView) findViewById(R.id.list_view);

    // Load the labels for the model, but only display those that don't start
    // with an underscore.
    loadLabels();

    // Build a list view based on these labels.
    ArrayAdapter<String> arrayAdapter =
        new ArrayAdapter<String>(this, R.layout.list_text_item, displayedLabels);
    labelsListView.setAdapter(arrayAdapter);

    // Set up an object to smooth recognition results to increase accuracy.
    recognizeCommands =
        new RecognizeCommands(
            labels,
            AVERAGE_WINDOW_DURATION_MS,
            DETECTION_THRESHOLD,
            SUPPRESSION_MS,
            MINIMUM_COUNT,
            MINIMUM_TIME_BETWEEN_SAMPLES_MS);

    // Load the TensorFlow model.
    inferenceInterface = new TensorFlowInferenceInterface(getAssets(), MODEL_FILENAME);

    // Start the recording and recognition threads only once RECORD_AUDIO is held. Starting them
    // before the user answers the permission prompt leaves a recording thread that fails to
    // initialize AudioRecord and exits, while recordingThread stays non-null; the later
    // startRecording() call from onRequestPermissionsResult() would then be a no-op.
    if (hasMicrophonePermission()) {
      startRecording();
      startRecognition();
    } else {
      requestMicrophonePermission();
    }
  }

  /** Stops both worker threads so they do not keep running after the activity is gone. */
  @Override
  protected void onDestroy() {
    stopRecording();
    stopRecognition();
    super.onDestroy();
  }

  /**
   * Reads the label asset named by LABEL_FILENAME. Every line is appended to {@code labels} (so
   * label indices stay aligned with the model's output scores), and the lines accepted by {@link
   * #isDisplayedLabel} are also appended to {@code displayedLabels} with their first letter
   * capitalized.
   *
   * @throws RuntimeException wrapping the IOException if the asset cannot be read
   */
  private void loadLabels() {
    String actualFilename = LABEL_FILENAME.split("file:///android_asset/")[1];
    Log.i(LOG_TAG, "Reading labels from: " + actualFilename);
    try (BufferedReader br =
        new BufferedReader(new InputStreamReader(getAssets().open(actualFilename)))) {
      String line;
      while ((line = br.readLine()) != null) {
        labels.add(line);
        if (isDisplayedLabel(line)) {
          // Locale.ROOT keeps capitalization independent of the device language.
          displayedLabels.add(line.substring(0, 1).toUpperCase(Locale.ROOT) + line.substring(1));
        }
      }
    } catch (IOException e) {
      throw new RuntimeException("Problem reading label file!", e);
    }
  }

  /**
   * Returns whether a label is shown in the list view: it must be non-empty and must not start
   * with an underscore.
   */
  private static boolean isDisplayedLabel(String label) {
    return !label.isEmpty() && label.charAt(0) != '_';
  }

  /** Returns true if RECORD_AUDIO is available (always the case below Android M). */
  private boolean hasMicrophonePermission() {
    return Build.VERSION.SDK_INT < Build.VERSION_CODES.M
        || checkSelfPermission(android.Manifest.permission.RECORD_AUDIO)
            == PackageManager.PERMISSION_GRANTED;
  }

  /** Asks the user for RECORD_AUDIO on Android M and later; the answer arrives in the callback. */
  private void requestMicrophonePermission() {
    if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
      requestPermissions(
          new String[] {android.Manifest.permission.RECORD_AUDIO}, REQUEST_RECORD_AUDIO);
    }
  }

  /** Starts the worker threads once the user grants the microphone permission. */
  @Override
  public void onRequestPermissionsResult(
      int requestCode, String[] permissions, int[] grantResults) {
    super.onRequestPermissionsResult(requestCode, permissions, grantResults);
    if (requestCode == REQUEST_RECORD_AUDIO
        && grantResults.length > 0
        && grantResults[0] == PackageManager.PERMISSION_GRANTED) {
      startRecording();
      startRecognition();
    }
  }

  /** Starts the recording thread unless one has already been started. */
  public synchronized void startRecording() {
    if (recordingThread != null) {
      return;
    }
    shouldContinue = true;
    recordingThread =
        new Thread(
            new Runnable() {
              @Override
              public void run() {
                record();
              }
            });
    recordingThread.start();
  }

  /**
   * Asks the recording thread to stop. The thread exits after its current read completes; this
   * method does not wait for it.
   */
  public synchronized void stopRecording() {
    if (recordingThread == null) {
      return;
    }
    shouldContinue = false;
    recordingThread = null;
  }

  /**
   * Body of the recording thread: captures 16-bit mono PCM at SAMPLE_RATE and copies it into the
   * round-robin {@code recordingBuffer} until {@code shouldContinue} is cleared or a read fails.
   */
  private void record() {
    android.os.Process.setThreadPriority(android.os.Process.THREAD_PRIORITY_AUDIO);

    // Estimate the buffer size we'll need for this device.
    int bufferSize =
        AudioRecord.getMinBufferSize(
            SAMPLE_RATE, AudioFormat.CHANNEL_IN_MONO, AudioFormat.ENCODING_PCM_16BIT);
    if (bufferSize == AudioRecord.ERROR || bufferSize == AudioRecord.ERROR_BAD_VALUE) {
      bufferSize = SAMPLE_RATE * 2;
    }
    // bufferSize is in bytes; each 16-bit sample takes two.
    short[] audioBuffer = new short[bufferSize / 2];

    AudioRecord record =
        new AudioRecord(
            MediaRecorder.AudioSource.DEFAULT,
            SAMPLE_RATE,
            AudioFormat.CHANNEL_IN_MONO,
            AudioFormat.ENCODING_PCM_16BIT,
            bufferSize);

    if (record.getState() != AudioRecord.STATE_INITIALIZED) {
      Log.e(LOG_TAG, "Audio Record can't initialize!");
      // Release the instance even though initialization failed.
      record.release();
      return;
    }

    try {
      record.startRecording();

      Log.v(LOG_TAG, "Start recording");

      // Loop, gathering audio data and copying it to a round-robin buffer.
      while (shouldContinue) {
        int numberRead = record.read(audioBuffer, 0, audioBuffer.length);
        if (numberRead < 0) {
          // read() reports failures as negative error codes. Using one as a copy length would
          // make System.arraycopy throw, so stop recording instead.
          Log.e(LOG_TAG, "AudioRecord.read failed with error " + numberRead);
          break;
        }
        int maxLength = recordingBuffer.length;
        int newRecordingOffset = recordingOffset + numberRead;
        // Samples that run past the end of the buffer wrap around to its start.
        int secondCopyLength = Math.max(0, newRecordingOffset - maxLength);
        int firstCopyLength = numberRead - secondCopyLength;
        // We store off all the data for the recognition thread to access. The ML
        // thread will copy out of this buffer into its own, while holding the
        // lock, so this should be thread safe.
        recordingBufferLock.lock();
        try {
          System.arraycopy(audioBuffer, 0, recordingBuffer, recordingOffset, firstCopyLength);
          System.arraycopy(audioBuffer, firstCopyLength, recordingBuffer, 0, secondCopyLength);
          recordingOffset = newRecordingOffset % maxLength;
        } finally {
          recordingBufferLock.unlock();
        }
      }
    } finally {
      record.stop();
      record.release();
    }
  }

  /** Starts the recognition thread unless one has already been started. */
  public synchronized void startRecognition() {
    if (recognitionThread != null) {
      return;
    }
    shouldContinueRecognition = true;
    recognitionThread =
        new Thread(
            new Runnable() {
              @Override
              public void run() {
                recognize();
              }
            });
    recognitionThread.start();
  }

  /**
   * Asks the recognition thread to stop. The thread exits after its current iteration; this method
   * does not wait for it.
   */
  public synchronized void stopRecognition() {
    if (recognitionThread == null) {
      return;
    }
    shouldContinueRecognition = false;
    recognitionThread = null;
  }

  /**
   * Body of the recognition thread: repeatedly snapshots the latest RECORDING_LENGTH samples, runs
   * the model on them, smooths the scores with {@code recognizeCommands}, and highlights newly
   * detected displayed commands on the UI thread.
   */
  private void recognize() {
    Log.v(LOG_TAG, "Start recognition");

    short[] inputBuffer = new short[RECORDING_LENGTH];
    float[] floatInputBuffer = new float[RECORDING_LENGTH];
    float[] outputScores = new float[labels.size()];
    String[] outputScoresNames = new String[] {OUTPUT_SCORES_NAME};
    int[] sampleRateList = new int[] {SAMPLE_RATE};

    // Loop, grabbing recorded data and running the recognition model on it.
    while (shouldContinueRecognition) {
      // The recording thread places data in this round-robin buffer, so lock to
      // make sure there's no writing happening and then copy it to our own
      // local version, oldest sample first.
      recordingBufferLock.lock();
      try {
        int maxLength = recordingBuffer.length;
        int firstCopyLength = maxLength - recordingOffset;
        int secondCopyLength = recordingOffset;
        System.arraycopy(recordingBuffer, recordingOffset, inputBuffer, 0, firstCopyLength);
        System.arraycopy(recordingBuffer, 0, inputBuffer, firstCopyLength, secondCopyLength);
      } finally {
        recordingBufferLock.unlock();
      }

      // We need to feed in float values between -1.0f and 1.0f, so divide the
      // signed 16-bit inputs.
      for (int i = 0; i < RECORDING_LENGTH; ++i) {
        floatInputBuffer[i] = inputBuffer[i] / 32767.0f;
      }

      // Run the model.
      inferenceInterface.feed(SAMPLE_RATE_NAME, sampleRateList);
      inferenceInterface.feed(INPUT_DATA_NAME, floatInputBuffer, RECORDING_LENGTH, 1);
      inferenceInterface.run(outputScoresNames);
      inferenceInterface.fetch(OUTPUT_SCORES_NAME, outputScores);

      // Use the smoother to figure out if we've had a real recognition event.
      long currentTime = System.currentTimeMillis();
      final RecognizeCommands.RecognitionResult result =
          recognizeCommands.processLatestResults(outputScores, currentTime);

      // Only post to the UI thread when there is a new, user-visible command to highlight.
      if (!result.foundCommand.startsWith("_") && result.isNewCommand) {
        runOnUiThread(
            new Runnable() {
              @Override
              public void run() {
                highlightLabel(result.foundCommand);
              }
            });
      }

      try {
        // We don't need to run too frequently, so snooze for a bit.
        Thread.sleep(MINIMUM_TIME_BETWEEN_SAMPLES_MS);
      } catch (InterruptedException e) {
        // Treat an interrupt as a request to stop, and preserve the interrupt status.
        Thread.currentThread().interrupt();
        break;
      }
    }

    Log.v(LOG_TAG, "End recognition");
  }

  /**
   * Plays the highlight animation on the list row showing {@code label}. Does nothing if the label
   * is not displayed or its row is not currently on screen. Called on the UI thread because it
   * touches views.
   */
  private void highlightLabel(String label) {
    int displayedIndex = displayedIndexOf(label);
    if (displayedIndex < 0) {
      return;
    }
    // getChildAt() indexes the rows that are currently laid out, not adapter positions, so offset
    // by the first visible position. It returns null when the row is scrolled out of view.
    View labelView =
        labelsListView.getChildAt(displayedIndex - labelsListView.getFirstVisiblePosition());
    if (labelView == null) {
      return;
    }

    AnimatorSet colorAnimation =
        (AnimatorSet) AnimatorInflater.loadAnimator(this, R.animator.color_animation);
    colorAnimation.setTarget(labelView);
    colorAnimation.start();
  }

  /**
   * Returns the position of {@code label} within the displayed (filtered) label list, or -1 if it
   * is not displayed.
   */
  private int displayedIndexOf(String label) {
    int displayedIndex = 0;
    for (String candidate : labels) {
      if (!isDisplayedLabel(candidate)) {
        continue;
      }
      if (candidate.equals(label)) {
        return displayedIndex;
      }
      displayedIndex++;
    }
    return -1;
  }
}