Home | History | Annotate | Download | only in demo
      1 /*
      2  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *       http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 /* Demonstrates how to run an audio recognition model in Android.
     18 
     19 This example loads a simple speech recognition model trained by the tutorial at
     20 https://www.tensorflow.org/tutorials/audio_training
     21 
     22 The model files should be downloaded automatically from the TensorFlow website,
     23 but if you have a custom model you can update the LABEL_FILENAME and
     24 MODEL_FILENAME constants to point to your own files.
     25 
     26 The example application displays a list view with all of the known audio labels,
     27 and highlights each one when it thinks it has detected one through the
     28 microphone. The averaging of results to give a more reliable signal happens in
     29 the RecognizeCommands helper class.
     30 */
     31 
     32 package org.tensorflow.demo;
     33 
     34 import android.animation.AnimatorInflater;
     35 import android.animation.AnimatorSet;
     36 import android.app.Activity;
     37 import android.content.pm.PackageManager;
     38 import android.media.AudioFormat;
     39 import android.media.AudioRecord;
     40 import android.media.MediaRecorder;
     41 import android.os.Build;
     42 import android.os.Bundle;
     43 import android.util.Log;
     44 import android.view.View;
     45 import android.widget.ArrayAdapter;
     46 import android.widget.Button;
     47 import android.widget.ListView;
     48 import java.io.BufferedReader;
     49 import java.io.IOException;
     50 import java.io.InputStreamReader;
     51 import java.util.ArrayList;
     52 import java.util.List;
     53 import java.util.concurrent.locks.ReentrantLock;
     54 import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
     55 import org.tensorflow.demo.R;
     56 
     57 /**
     58  * An activity that listens for audio and then uses a TensorFlow model to detect particular classes,
     59  * by default a small set of action words.
     60  */
     61 public class SpeechActivity extends Activity {
     62 
     63   // Constants that control the behavior of the recognition code and model
     64   // settings. See the audio recognition tutorial for a detailed explanation of
     65   // all these, but you should customize them to match your training settings if
     66   // you are running your own model.
     67   private static final int SAMPLE_RATE = 16000;
     68   private static final int SAMPLE_DURATION_MS = 1000;
     69   private static final int RECORDING_LENGTH = (int) (SAMPLE_RATE * SAMPLE_DURATION_MS / 1000);
     70   private static final long AVERAGE_WINDOW_DURATION_MS = 500;
     71   private static final float DETECTION_THRESHOLD = 0.70f;
     72   private static final int SUPPRESSION_MS = 1500;
     73   private static final int MINIMUM_COUNT = 3;
     74   private static final long MINIMUM_TIME_BETWEEN_SAMPLES_MS = 30;
     75   private static final String LABEL_FILENAME = "file:///android_asset/conv_actions_labels.txt";
     76   private static final String MODEL_FILENAME = "file:///android_asset/conv_actions_frozen.pb";
     77   private static final String INPUT_DATA_NAME = "decoded_sample_data:0";
     78   private static final String SAMPLE_RATE_NAME = "decoded_sample_data:1";
     79   private static final String OUTPUT_SCORES_NAME = "labels_softmax";
     80 
     81   // UI elements.
     82   private static final int REQUEST_RECORD_AUDIO = 13;
     83   private Button quitButton;
     84   private ListView labelsListView;
     85   private static final String LOG_TAG = SpeechActivity.class.getSimpleName();
     86 
     87   // Working variables.
     88   short[] recordingBuffer = new short[RECORDING_LENGTH];
     89   int recordingOffset = 0;
     90   boolean shouldContinue = true;
     91   private Thread recordingThread;
     92   boolean shouldContinueRecognition = true;
     93   private Thread recognitionThread;
     94   private final ReentrantLock recordingBufferLock = new ReentrantLock();
     95   private TensorFlowInferenceInterface inferenceInterface;
     96   private List<String> labels = new ArrayList<String>();
     97   private List<String> displayedLabels = new ArrayList<>();
     98   private RecognizeCommands recognizeCommands = null;
     99 
    100   @Override
    101   protected void onCreate(Bundle savedInstanceState) {
    102     // Set up the UI.
    103     super.onCreate(savedInstanceState);
    104     setContentView(R.layout.activity_speech);
    105     quitButton = (Button) findViewById(R.id.quit);
    106     quitButton.setOnClickListener(
    107         new View.OnClickListener() {
    108           @Override
    109           public void onClick(View view) {
    110             moveTaskToBack(true);
    111             android.os.Process.killProcess(android.os.Process.myPid());
    112             System.exit(1);
    113           }
    114         });
    115     labelsListView = (ListView) findViewById(R.id.list_view);
    116 
    117     // Load the labels for the model, but only display those that don't start
    118     // with an underscore.
    119     String actualFilename = LABEL_FILENAME.split("file:///android_asset/")[1];
    120     Log.i(LOG_TAG, "Reading labels from: " + actualFilename);
    121     BufferedReader br = null;
    122     try {
    123       br = new BufferedReader(new InputStreamReader(getAssets().open(actualFilename)));
    124       String line;
    125       while ((line = br.readLine()) != null) {
    126         labels.add(line);
    127         if (line.charAt(0) != '_') {
    128           displayedLabels.add(line.substring(0, 1).toUpperCase() + line.substring(1));
    129         }
    130       }
    131       br.close();
    132     } catch (IOException e) {
    133       throw new RuntimeException("Problem reading label file!", e);
    134     }
    135 
    136     // Build a list view based on these labels.
    137     ArrayAdapter<String> arrayAdapter =
    138         new ArrayAdapter<String>(this, R.layout.list_text_item, displayedLabels);
    139     labelsListView.setAdapter(arrayAdapter);
    140 
    141     // Set up an object to smooth recognition results to increase accuracy.
    142     recognizeCommands =
    143         new RecognizeCommands(
    144             labels,
    145             AVERAGE_WINDOW_DURATION_MS,
    146             DETECTION_THRESHOLD,
    147             SUPPRESSION_MS,
    148             MINIMUM_COUNT,
    149             MINIMUM_TIME_BETWEEN_SAMPLES_MS);
    150 
    151     // Load the TensorFlow model.
    152     inferenceInterface = new TensorFlowInferenceInterface(getAssets(), MODEL_FILENAME);
    153 
    154     // Start the recording and recognition threads.
    155     requestMicrophonePermission();
    156     startRecording();
    157     startRecognition();
    158   }
    159 
    160   private void requestMicrophonePermission() {
    161     if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
    162       requestPermissions(
    163           new String[]{android.Manifest.permission.RECORD_AUDIO}, REQUEST_RECORD_AUDIO);
    164     }
    165   }
    166 
    167   @Override
    168   public void onRequestPermissionsResult(
    169       int requestCode, String[] permissions, int[] grantResults) {
    170     if (requestCode == REQUEST_RECORD_AUDIO
    171         && grantResults.length > 0
    172         && grantResults[0] == PackageManager.PERMISSION_GRANTED) {
    173       startRecording();
    174       startRecognition();
    175     }
    176   }
    177 
    178   public synchronized void startRecording() {
    179     if (recordingThread != null) {
    180       return;
    181     }
    182     shouldContinue = true;
    183     recordingThread =
    184         new Thread(
    185             new Runnable() {
    186               @Override
    187               public void run() {
    188                 record();
    189               }
    190             });
    191     recordingThread.start();
    192   }
    193 
    194   public synchronized void stopRecording() {
    195     if (recordingThread == null) {
    196       return;
    197     }
    198     shouldContinue = false;
    199     recordingThread = null;
    200   }
    201 
    202   private void record() {
    203     android.os.Process.setThreadPriority(android.os.Process.THREAD_PRIORITY_AUDIO);
    204 
    205     // Estimate the buffer size we'll need for this device.
    206     int bufferSize =
    207         AudioRecord.getMinBufferSize(
    208             SAMPLE_RATE, AudioFormat.CHANNEL_IN_MONO, AudioFormat.ENCODING_PCM_16BIT);
    209     if (bufferSize == AudioRecord.ERROR || bufferSize == AudioRecord.ERROR_BAD_VALUE) {
    210       bufferSize = SAMPLE_RATE * 2;
    211     }
    212     short[] audioBuffer = new short[bufferSize / 2];
    213 
    214     AudioRecord record =
    215         new AudioRecord(
    216             MediaRecorder.AudioSource.DEFAULT,
    217             SAMPLE_RATE,
    218             AudioFormat.CHANNEL_IN_MONO,
    219             AudioFormat.ENCODING_PCM_16BIT,
    220             bufferSize);
    221 
    222     if (record.getState() != AudioRecord.STATE_INITIALIZED) {
    223       Log.e(LOG_TAG, "Audio Record can't initialize!");
    224       return;
    225     }
    226 
    227     record.startRecording();
    228 
    229     Log.v(LOG_TAG, "Start recording");
    230 
    231     // Loop, gathering audio data and copying it to a round-robin buffer.
    232     while (shouldContinue) {
    233       int numberRead = record.read(audioBuffer, 0, audioBuffer.length);
    234       int maxLength = recordingBuffer.length;
    235       int newRecordingOffset = recordingOffset + numberRead;
    236       int secondCopyLength = Math.max(0, newRecordingOffset - maxLength);
    237       int firstCopyLength = numberRead - secondCopyLength;
    238       // We store off all the data for the recognition thread to access. The ML
    239       // thread will copy out of this buffer into its own, while holding the
    240       // lock, so this should be thread safe.
    241       recordingBufferLock.lock();
    242       try {
    243         System.arraycopy(audioBuffer, 0, recordingBuffer, recordingOffset, firstCopyLength);
    244         System.arraycopy(audioBuffer, firstCopyLength, recordingBuffer, 0, secondCopyLength);
    245         recordingOffset = newRecordingOffset % maxLength;
    246       } finally {
    247         recordingBufferLock.unlock();
    248       }
    249     }
    250 
    251     record.stop();
    252     record.release();
    253   }
    254 
    255   public synchronized void startRecognition() {
    256     if (recognitionThread != null) {
    257       return;
    258     }
    259     shouldContinueRecognition = true;
    260     recognitionThread =
    261         new Thread(
    262             new Runnable() {
    263               @Override
    264               public void run() {
    265                 recognize();
    266               }
    267             });
    268     recognitionThread.start();
    269   }
    270 
    271   public synchronized void stopRecognition() {
    272     if (recognitionThread == null) {
    273       return;
    274     }
    275     shouldContinueRecognition = false;
    276     recognitionThread = null;
    277   }
    278 
    279   private void recognize() {
    280     Log.v(LOG_TAG, "Start recognition");
    281 
    282     short[] inputBuffer = new short[RECORDING_LENGTH];
    283     float[] floatInputBuffer = new float[RECORDING_LENGTH];
    284     float[] outputScores = new float[labels.size()];
    285     String[] outputScoresNames = new String[] {OUTPUT_SCORES_NAME};
    286     int[] sampleRateList = new int[] {SAMPLE_RATE};
    287 
    288     // Loop, grabbing recorded data and running the recognition model on it.
    289     while (shouldContinueRecognition) {
    290       // The recording thread places data in this round-robin buffer, so lock to
    291       // make sure there's no writing happening and then copy it to our own
    292       // local version.
    293       recordingBufferLock.lock();
    294       try {
    295         int maxLength = recordingBuffer.length;
    296         int firstCopyLength = maxLength - recordingOffset;
    297         int secondCopyLength = recordingOffset;
    298         System.arraycopy(recordingBuffer, recordingOffset, inputBuffer, 0, firstCopyLength);
    299         System.arraycopy(recordingBuffer, 0, inputBuffer, firstCopyLength, secondCopyLength);
    300       } finally {
    301         recordingBufferLock.unlock();
    302       }
    303 
    304       // We need to feed in float values between -1.0f and 1.0f, so divide the
    305       // signed 16-bit inputs.
    306       for (int i = 0; i < RECORDING_LENGTH; ++i) {
    307         floatInputBuffer[i] = inputBuffer[i] / 32767.0f;
    308       }
    309 
    310       // Run the model.
    311       inferenceInterface.feed(SAMPLE_RATE_NAME, sampleRateList);
    312       inferenceInterface.feed(INPUT_DATA_NAME, floatInputBuffer, RECORDING_LENGTH, 1);
    313       inferenceInterface.run(outputScoresNames);
    314       inferenceInterface.fetch(OUTPUT_SCORES_NAME, outputScores);
    315 
    316       // Use the smoother to figure out if we've had a real recognition event.
    317       long currentTime = System.currentTimeMillis();
    318       final RecognizeCommands.RecognitionResult result =
    319           recognizeCommands.processLatestResults(outputScores, currentTime);
    320 
    321       runOnUiThread(
    322           new Runnable() {
    323             @Override
    324             public void run() {
    325               // If we do have a new command, highlight the right list entry.
    326               if (!result.foundCommand.startsWith("_") && result.isNewCommand) {
    327                 int labelIndex = -1;
    328                 for (int i = 0; i < labels.size(); ++i) {
    329                   if (labels.get(i).equals(result.foundCommand)) {
    330                     labelIndex = i;
    331                   }
    332                 }
    333                 final View labelView = labelsListView.getChildAt(labelIndex - 2);
    334 
    335                 AnimatorSet colorAnimation = (AnimatorSet) AnimatorInflater.loadAnimator(
    336                     SpeechActivity.this, R.animator.color_animation);
    337                 colorAnimation.setTarget(labelView);
    338                 colorAnimation.start();
    339               }
    340             }
    341           });
    342       try {
    343         // We don't need to run too frequently, so snooze for a bit.
    344         Thread.sleep(MINIMUM_TIME_BETWEEN_SAMPLES_MS);
    345       } catch (InterruptedException e) {
    346         // Ignore
    347       }
    348     }
    349 
    350     Log.v(LOG_TAG, "End recognition");
    351   }
    352 }
    353