1 /* 2 * Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package org.tensorflow.demo; 18 19 import android.graphics.Bitmap; 20 import android.graphics.Bitmap.Config; 21 import android.graphics.Canvas; 22 import android.graphics.Color; 23 import android.graphics.Matrix; 24 import android.graphics.Paint; 25 import android.graphics.Paint.Style; 26 import android.graphics.RectF; 27 import android.graphics.Typeface; 28 import android.media.ImageReader.OnImageAvailableListener; 29 import android.os.SystemClock; 30 import android.util.Size; 31 import android.util.TypedValue; 32 import android.widget.Toast; 33 import java.io.IOException; 34 import java.util.LinkedList; 35 import java.util.List; 36 import java.util.Vector; 37 import org.tensorflow.demo.OverlayView.DrawCallback; 38 import org.tensorflow.demo.env.BorderedText; 39 import org.tensorflow.demo.env.ImageUtils; 40 import org.tensorflow.demo.env.Logger; 41 import org.tensorflow.demo.tracking.MultiBoxTracker; 42 import org.tensorflow.lite.demo.R; // Explicit import needed for internal Google builds. 43 44 /** 45 * An activity that uses a TensorFlowMultiBoxDetector and ObjectTracker to detect and then track 46 * objects. 47 */ 48 public class DetectorActivity extends CameraActivity implements OnImageAvailableListener { 49 private static final Logger LOGGER = new Logger(); 50 51 // Configuration values for the prepackaged SSD model. 52 private static final int TF_OD_API_INPUT_SIZE = 300; 53 private static final boolean TF_OD_API_IS_QUANTIZED = true; 54 private static final String TF_OD_API_MODEL_FILE = "detect.tflite"; 55 private static final String TF_OD_API_LABELS_FILE = "coco_labels_list.txt"; 56 57 // Which detection model to use: by default uses Tensorflow Object Detection API frozen 58 // checkpoints. 59 private enum DetectorMode { 60 TF_OD_API; 61 } 62 63 private static final DetectorMode MODE = DetectorMode.TF_OD_API; 64 65 // Minimum detection confidence to track a detection. 66 private static final float MINIMUM_CONFIDENCE_TF_OD_API = 0.6f; 67 68 private static final boolean MAINTAIN_ASPECT = false; 69 70 private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480); 71 72 private static final boolean SAVE_PREVIEW_BITMAP = false; 73 private static final float TEXT_SIZE_DIP = 10; 74 75 private Integer sensorOrientation; 76 77 private Classifier detector; 78 79 private long lastProcessingTimeMs; 80 private Bitmap rgbFrameBitmap = null; 81 private Bitmap croppedBitmap = null; 82 private Bitmap cropCopyBitmap = null; 83 84 private boolean computingDetection = false; 85 86 private long timestamp = 0; 87 88 private Matrix frameToCropTransform; 89 private Matrix cropToFrameTransform; 90 91 private MultiBoxTracker tracker; 92 93 private byte[] luminanceCopy; 94 95 private BorderedText borderedText; 96 @Override 97 public void onPreviewSizeChosen(final Size size, final int rotation) { 98 final float textSizePx = 99 TypedValue.applyDimension( 100 TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics()); 101 borderedText = new BorderedText(textSizePx); 102 borderedText.setTypeface(Typeface.MONOSPACE); 103 104 tracker = new MultiBoxTracker(this); 105 106 int cropSize = TF_OD_API_INPUT_SIZE; 107 108 try { 109 detector = 110 TFLiteObjectDetectionAPIModel.create( 111 getAssets(), 112 TF_OD_API_MODEL_FILE, 113 TF_OD_API_LABELS_FILE, 114 TF_OD_API_INPUT_SIZE, 115 TF_OD_API_IS_QUANTIZED); 116 cropSize = TF_OD_API_INPUT_SIZE; 117 } catch (final IOException e) { 118 LOGGER.e("Exception initializing classifier!", e); 119 Toast toast = 120 Toast.makeText( 121 getApplicationContext(), "Classifier could not be initialized", Toast.LENGTH_SHORT); 122 toast.show(); 123 finish(); 124 } 125 126 127 previewWidth = size.getWidth(); 128 previewHeight = size.getHeight(); 129 130 sensorOrientation = rotation - getScreenOrientation(); 131 LOGGER.i("Camera orientation relative to screen canvas: %d", sensorOrientation); 132 133 LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight); 134 rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888); 135 croppedBitmap = Bitmap.createBitmap(cropSize, cropSize, Config.ARGB_8888); 136 137 frameToCropTransform = 138 ImageUtils.getTransformationMatrix( 139 previewWidth, previewHeight, 140 cropSize, cropSize, 141 sensorOrientation, MAINTAIN_ASPECT); 142 143 cropToFrameTransform = new Matrix(); 144 frameToCropTransform.invert(cropToFrameTransform); 145 146 trackingOverlay = (OverlayView) findViewById(R.id.tracking_overlay); 147 trackingOverlay.addCallback( 148 new DrawCallback() { 149 @Override 150 public void drawCallback(final Canvas canvas) { 151 tracker.draw(canvas); 152 if (isDebug()) { 153 tracker.drawDebug(canvas); 154 } 155 } 156 }); 157 158 addCallback( 159 new DrawCallback() { 160 @Override 161 public void drawCallback(final Canvas canvas) { 162 if (!isDebug()) { 163 return; 164 } 165 final Bitmap copy = cropCopyBitmap; 166 if (copy == null) { 167 return; 168 } 169 170 final int backgroundColor = Color.argb(100, 0, 0, 0); 171 canvas.drawColor(backgroundColor); 172 173 final Matrix matrix = new Matrix(); 174 final float scaleFactor = 2; 175 matrix.postScale(scaleFactor, scaleFactor); 176 matrix.postTranslate( 177 canvas.getWidth() - copy.getWidth() * scaleFactor, 178 canvas.getHeight() - copy.getHeight() * scaleFactor); 179 canvas.drawBitmap(copy, matrix, new Paint()); 180 181 final Vector<String> lines = new Vector<String>(); 182 if (detector != null) { 183 final String statString = detector.getStatString(); 184 final String[] statLines = statString.split("\n"); 185 for (final String line : statLines) { 186 lines.add(line); 187 } 188 } 189 lines.add(""); 190 191 lines.add("Frame: " + previewWidth + "x" + previewHeight); 192 lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight()); 193 lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight()); 194 lines.add("Rotation: " + sensorOrientation); 195 lines.add("Inference time: " + lastProcessingTimeMs + "ms"); 196 197 borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines); 198 } 199 }); 200 } 201 202 OverlayView trackingOverlay; 203 204 @Override 205 protected void processImage() { 206 ++timestamp; 207 final long currTimestamp = timestamp; 208 byte[] originalLuminance = getLuminance(); 209 tracker.onFrame( 210 previewWidth, 211 previewHeight, 212 getLuminanceStride(), 213 sensorOrientation, 214 originalLuminance, 215 timestamp); 216 trackingOverlay.postInvalidate(); 217 218 // No mutex needed as this method is not reentrant. 219 if (computingDetection) { 220 readyForNextImage(); 221 return; 222 } 223 computingDetection = true; 224 LOGGER.i("Preparing image " + currTimestamp + " for detection in bg thread."); 225 226 rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight); 227 228 if (luminanceCopy == null) { 229 luminanceCopy = new byte[originalLuminance.length]; 230 } 231 System.arraycopy(originalLuminance, 0, luminanceCopy, 0, originalLuminance.length); 232 readyForNextImage(); 233 234 final Canvas canvas = new Canvas(croppedBitmap); 235 canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null); 236 // For examining the actual TF input. 237 if (SAVE_PREVIEW_BITMAP) { 238 ImageUtils.saveBitmap(croppedBitmap); 239 } 240 241 runInBackground( 242 new Runnable() { 243 @Override 244 public void run() { 245 LOGGER.i("Running detection on image " + currTimestamp); 246 final long startTime = SystemClock.uptimeMillis(); 247 final List<Classifier.Recognition> results = detector.recognizeImage(croppedBitmap); 248 lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime; 249 250 cropCopyBitmap = Bitmap.createBitmap(croppedBitmap); 251 final Canvas canvas = new Canvas(cropCopyBitmap); 252 final Paint paint = new Paint(); 253 paint.setColor(Color.RED); 254 paint.setStyle(Style.STROKE); 255 paint.setStrokeWidth(2.0f); 256 257 float minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API; 258 switch (MODE) { 259 case TF_OD_API: 260 minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API; 261 break; 262 } 263 264 final List<Classifier.Recognition> mappedRecognitions = 265 new LinkedList<Classifier.Recognition>(); 266 267 for (final Classifier.Recognition result : results) { 268 final RectF location = result.getLocation(); 269 if (location != null && result.getConfidence() >= minimumConfidence) { 270 canvas.drawRect(location, paint); 271 272 cropToFrameTransform.mapRect(location); 273 result.setLocation(location); 274 mappedRecognitions.add(result); 275 } 276 } 277 278 tracker.trackResults(mappedRecognitions, luminanceCopy, currTimestamp); 279 trackingOverlay.postInvalidate(); 280 281 requestRender(); 282 computingDetection = false; 283 } 284 }); 285 } 286 287 @Override 288 protected int getLayoutId() { 289 return R.layout.camera_connection_fragment_tracking; 290 } 291 292 @Override 293 protected Size getDesiredPreviewFrameSize() { 294 return DESIRED_PREVIEW_SIZE; 295 } 296 297 @Override 298 public void onSetDebug(final boolean debug) { 299 detector.enableStatLogging(debug); 300 } 301 } 302