/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.example.android.nn.benchmark;


import android.app.Activity;
import android.os.Bundle;
import android.test.ActivityInstrumentationTestCase2;
import android.test.suitebuilder.annotation.MediumTest;
import android.util.Log;

import com.example.android.nn.benchmark.NNTestList.TestName;

import java.util.concurrent.CountDownLatch;

/**
 * NNAPI benchmark test.
 * To run the test, please use command
 *
 * adb shell am instrument -w com.example.android.nn.benchmark/android.support.test.runner.AndroidJUnitRunner
 *
 */
public class NNTest extends ActivityInstrumentationTestCase2<NNBenchmark> {
    // Only run 1 iteration now to fit the MediumTest time requirement.
    // One iteration means running the tests continuously for 1s.
    private final int mIteration = 1;
    private NNBenchmark mActivity;

    public NNTest() {
        super(NNBenchmark.class);
    }

    /**
     * Launches (or fetches) the {@link NNBenchmark} activity under test and asks it to
     * prepare itself for an instrumentation run.
     */
    protected void prepareTest() {
        mActivity = getActivity();
        mActivity.prepareInstrumentationTest();
    }

    @Override
    public void setUp() throws Exception {
        super.setUp();
        // The initial touch mode must be configured before the first getActivity() call
        // (made inside prepareTest()); setting it afterwards has no effect.
        setActivityInitialTouchMode(false);
        prepareTest();
    }

    @Override
    public void tearDown() throws Exception {
        super.tearDown();
    }

    /**
     * Runnable that asks the activity's processor for the instrumentation result of a
     * single test and keeps the value so the test thread can read it afterwards.
     */
    class TestAction implements Runnable {
        TestName mTestName;
        float mResult;

        public TestAction(TestName testName) {
            mTestName = testName;
        }

        @Override
        public void run() {
            mResult = mActivity.mProcessor.getInstrumentationResult(mTestName);
            Log.v(NNBenchmark.TAG,
                    "Benchmark for test \"" + mTestName.toString() + "\" is: " + mResult);
        }

        /**
         * @return the value stored by the most recent {@link #run()}.
         */
        public float getBenchmark() {
            return mResult;
        }
    }

    /**
     * Runs {@code action} on the activity's UI thread and blocks the calling thread until
     * the action has finished.
     *
     * <p>Completion is signalled through a latch rather than a bare wait()/notify() pair,
     * so a spurious wakeup cannot release the caller early, the action does not need to
     * notify anyone itself, and an action that ends up running synchronously on the
     * calling thread does not leave the caller blocked. The latch also makes the action's
     * writes visible to the caller once this method returns.
     *
     * <p>If the calling thread is interrupted while waiting, the interruption is logged,
     * the interrupt status is restored and the method returns; in that case the action may
     * not have completed yet.
     *
     * @param action the work to execute on the UI thread
     */
    public void runOnUiThread(final Runnable action) {
        final CountDownLatch finished = new CountDownLatch(1);
        mActivity.runOnUiThread(new Runnable() {
            @Override
            public void run() {
                try {
                    action.run();
                } finally {
                    // Release the waiting test thread even if the action throws.
                    finished.countDown();
                }
            }
        });
        try {
            finished.await();
        } catch (InterruptedException e) {
            Log.v(NNBenchmark.TAG, "waiting for action running on UI thread is interrupted: " +
                    e.toString());
            // Preserve the interrupt for callers further up the stack.
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Executes {@code ta} on the UI thread {@code mIteration} times, averages the reported
     * benchmark values and posts the average to INSTRUMENTATION_STATUS under the key
     * {@code testName + "_avg"}.
     *
     * @param ta       the benchmark action to run
     * @param testName prefix of the key under which the average is reported
     */
    public void runTest(TestAction ta, String testName) {
        float sum = 0;
        for (int i = 0; i < mIteration; i++) {
            runOnUiThread(ta);
            float bmValue = ta.getBenchmark();
            Log.v(NNBenchmark.TAG, "results for iteration " + i + " is " + bmValue);
            sum += bmValue;
        }
        float avgResult = sum / mIteration;

        // post result to INSTRUMENTATION_STATUS
        Bundle results = new Bundle();
        results.putFloat(testName + "_avg", avgResult);
        getInstrumentation().sendStatus(Activity.RESULT_OK, results);
    }

    // Test case 0: MobileNet float32
    @MediumTest
    public void testMobileNetFloat() {
        TestAction ta = new TestAction(TestName.MobileNet_FLOAT);
        runTest(ta, TestName.MobileNet_FLOAT.name());
    }

    // Test case 1: MobileNet quantized
    @MediumTest
    public void testMobileNetQuantized() {
        TestAction ta = new TestAction(TestName.MobileNet_QUANT8);
        runTest(ta, TestName.MobileNet_QUANT8.name());
    }
}