Home | History | Annotate | Download | only in demo
      1 /*
      2  * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *       http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package org.tensorflow.demo;
     18 
     19 import android.graphics.Bitmap;
     20 import android.graphics.Bitmap.Config;
     21 import android.graphics.Canvas;
     22 import android.graphics.Color;
     23 import android.graphics.Matrix;
     24 import android.graphics.Paint;
     25 import android.graphics.Paint.Style;
     26 import android.graphics.RectF;
     27 import android.graphics.Typeface;
     28 import android.media.ImageReader.OnImageAvailableListener;
     29 import android.os.SystemClock;
     30 import android.util.Size;
     31 import android.util.TypedValue;
     32 import android.widget.Toast;
     33 import java.io.IOException;
     34 import java.util.LinkedList;
     35 import java.util.List;
     36 import java.util.Vector;
     37 import org.tensorflow.demo.OverlayView.DrawCallback;
     38 import org.tensorflow.demo.env.BorderedText;
     39 import org.tensorflow.demo.env.ImageUtils;
     40 import org.tensorflow.demo.env.Logger;
     41 import org.tensorflow.demo.tracking.MultiBoxTracker;
     42 import org.tensorflow.lite.demo.R; // Explicit import needed for internal Google builds.
     43 
     44 /**
     45  * An activity that uses a TensorFlowMultiBoxDetector and ObjectTracker to detect and then track
     46  * objects.
     47  */
     48 public class DetectorActivity extends CameraActivity implements OnImageAvailableListener {
     49   private static final Logger LOGGER = new Logger();
     50 
     51   // Configuration values for the prepackaged SSD model.
     52   private static final int TF_OD_API_INPUT_SIZE = 300;
     53   private static final boolean TF_OD_API_IS_QUANTIZED = true;
     54   private static final String TF_OD_API_MODEL_FILE = "detect.tflite";
     55   private static final String TF_OD_API_LABELS_FILE = "coco_labels_list.txt";
     56 
     57   // Which detection model to use: by default uses Tensorflow Object Detection API frozen
     58   // checkpoints.
     59   private enum DetectorMode {
     60     TF_OD_API;
     61   }
     62 
     63   private static final DetectorMode MODE = DetectorMode.TF_OD_API;
     64 
     65   // Minimum detection confidence to track a detection.
     66   private static final float MINIMUM_CONFIDENCE_TF_OD_API = 0.6f;
     67 
     68   private static final boolean MAINTAIN_ASPECT = false;
     69 
     70   private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480);
     71 
     72   private static final boolean SAVE_PREVIEW_BITMAP = false;
     73   private static final float TEXT_SIZE_DIP = 10;
     74 
     75   private Integer sensorOrientation;
     76 
     77   private Classifier detector;
     78 
     79   private long lastProcessingTimeMs;
     80   private Bitmap rgbFrameBitmap = null;
     81   private Bitmap croppedBitmap = null;
     82   private Bitmap cropCopyBitmap = null;
     83 
     84   private boolean computingDetection = false;
     85 
     86   private long timestamp = 0;
     87 
     88   private Matrix frameToCropTransform;
     89   private Matrix cropToFrameTransform;
     90 
     91   private MultiBoxTracker tracker;
     92 
     93   private byte[] luminanceCopy;
     94 
     95   private BorderedText borderedText;
     96   @Override
     97   public void onPreviewSizeChosen(final Size size, final int rotation) {
     98     final float textSizePx =
     99         TypedValue.applyDimension(
    100             TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
    101     borderedText = new BorderedText(textSizePx);
    102     borderedText.setTypeface(Typeface.MONOSPACE);
    103 
    104     tracker = new MultiBoxTracker(this);
    105 
    106     int cropSize = TF_OD_API_INPUT_SIZE;
    107 
    108     try {
    109       detector =
    110           TFLiteObjectDetectionAPIModel.create(
    111               getAssets(),
    112               TF_OD_API_MODEL_FILE,
    113               TF_OD_API_LABELS_FILE,
    114               TF_OD_API_INPUT_SIZE,
    115               TF_OD_API_IS_QUANTIZED);
    116       cropSize = TF_OD_API_INPUT_SIZE;
    117     } catch (final IOException e) {
    118       LOGGER.e("Exception initializing classifier!", e);
    119       Toast toast =
    120           Toast.makeText(
    121               getApplicationContext(), "Classifier could not be initialized", Toast.LENGTH_SHORT);
    122       toast.show();
    123       finish();
    124     }
    125 
    126 
    127     previewWidth = size.getWidth();
    128     previewHeight = size.getHeight();
    129 
    130     sensorOrientation = rotation - getScreenOrientation();
    131     LOGGER.i("Camera orientation relative to screen canvas: %d", sensorOrientation);
    132 
    133     LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
    134     rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
    135     croppedBitmap = Bitmap.createBitmap(cropSize, cropSize, Config.ARGB_8888);
    136 
    137     frameToCropTransform =
    138         ImageUtils.getTransformationMatrix(
    139             previewWidth, previewHeight,
    140             cropSize, cropSize,
    141             sensorOrientation, MAINTAIN_ASPECT);
    142 
    143     cropToFrameTransform = new Matrix();
    144     frameToCropTransform.invert(cropToFrameTransform);
    145 
    146     trackingOverlay = (OverlayView) findViewById(R.id.tracking_overlay);
    147     trackingOverlay.addCallback(
    148         new DrawCallback() {
    149           @Override
    150           public void drawCallback(final Canvas canvas) {
    151             tracker.draw(canvas);
    152             if (isDebug()) {
    153               tracker.drawDebug(canvas);
    154             }
    155           }
    156         });
    157 
    158     addCallback(
    159         new DrawCallback() {
    160           @Override
    161           public void drawCallback(final Canvas canvas) {
    162             if (!isDebug()) {
    163               return;
    164             }
    165             final Bitmap copy = cropCopyBitmap;
    166             if (copy == null) {
    167               return;
    168             }
    169 
    170             final int backgroundColor = Color.argb(100, 0, 0, 0);
    171             canvas.drawColor(backgroundColor);
    172 
    173             final Matrix matrix = new Matrix();
    174             final float scaleFactor = 2;
    175             matrix.postScale(scaleFactor, scaleFactor);
    176             matrix.postTranslate(
    177                 canvas.getWidth() - copy.getWidth() * scaleFactor,
    178                 canvas.getHeight() - copy.getHeight() * scaleFactor);
    179             canvas.drawBitmap(copy, matrix, new Paint());
    180 
    181             final Vector<String> lines = new Vector<String>();
    182             if (detector != null) {
    183               final String statString = detector.getStatString();
    184               final String[] statLines = statString.split("\n");
    185               for (final String line : statLines) {
    186                 lines.add(line);
    187               }
    188             }
    189             lines.add("");
    190 
    191             lines.add("Frame: " + previewWidth + "x" + previewHeight);
    192             lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight());
    193             lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight());
    194             lines.add("Rotation: " + sensorOrientation);
    195             lines.add("Inference time: " + lastProcessingTimeMs + "ms");
    196 
    197             borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines);
    198           }
    199         });
    200   }
    201 
    202   OverlayView trackingOverlay;
    203 
    204   @Override
    205   protected void processImage() {
    206     ++timestamp;
    207     final long currTimestamp = timestamp;
    208     byte[] originalLuminance = getLuminance();
    209     tracker.onFrame(
    210         previewWidth,
    211         previewHeight,
    212         getLuminanceStride(),
    213         sensorOrientation,
    214         originalLuminance,
    215         timestamp);
    216     trackingOverlay.postInvalidate();
    217 
    218     // No mutex needed as this method is not reentrant.
    219     if (computingDetection) {
    220       readyForNextImage();
    221       return;
    222     }
    223     computingDetection = true;
    224     LOGGER.i("Preparing image " + currTimestamp + " for detection in bg thread.");
    225 
    226     rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight);
    227 
    228     if (luminanceCopy == null) {
    229       luminanceCopy = new byte[originalLuminance.length];
    230     }
    231     System.arraycopy(originalLuminance, 0, luminanceCopy, 0, originalLuminance.length);
    232     readyForNextImage();
    233 
    234     final Canvas canvas = new Canvas(croppedBitmap);
    235     canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);
    236     // For examining the actual TF input.
    237     if (SAVE_PREVIEW_BITMAP) {
    238       ImageUtils.saveBitmap(croppedBitmap);
    239     }
    240 
    241     runInBackground(
    242         new Runnable() {
    243           @Override
    244           public void run() {
    245             LOGGER.i("Running detection on image " + currTimestamp);
    246             final long startTime = SystemClock.uptimeMillis();
    247             final List<Classifier.Recognition> results = detector.recognizeImage(croppedBitmap);
    248             lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
    249 
    250             cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
    251             final Canvas canvas = new Canvas(cropCopyBitmap);
    252             final Paint paint = new Paint();
    253             paint.setColor(Color.RED);
    254             paint.setStyle(Style.STROKE);
    255             paint.setStrokeWidth(2.0f);
    256 
    257             float minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API;
    258             switch (MODE) {
    259               case TF_OD_API:
    260                 minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API;
    261                 break;
    262             }
    263 
    264             final List<Classifier.Recognition> mappedRecognitions =
    265                 new LinkedList<Classifier.Recognition>();
    266 
    267             for (final Classifier.Recognition result : results) {
    268               final RectF location = result.getLocation();
    269               if (location != null && result.getConfidence() >= minimumConfidence) {
    270                 canvas.drawRect(location, paint);
    271 
    272                 cropToFrameTransform.mapRect(location);
    273                 result.setLocation(location);
    274                 mappedRecognitions.add(result);
    275               }
    276             }
    277 
    278             tracker.trackResults(mappedRecognitions, luminanceCopy, currTimestamp);
    279             trackingOverlay.postInvalidate();
    280 
    281             requestRender();
    282             computingDetection = false;
    283           }
    284         });
    285   }
    286 
    287   @Override
    288   protected int getLayoutId() {
    289     return R.layout.camera_connection_fragment_tracking;
    290   }
    291 
    292   @Override
    293   protected Size getDesiredPreviewFrameSize() {
    294     return DESIRED_PREVIEW_SIZE;
    295   }
    296 
    297   @Override
    298   public void onSetDebug(final boolean debug) {
    299     detector.enableStatLogging(debug);
    300   }
    301 }
    302