Home | History | Annotate | Download | only in blasbenchmark
      1 /*
      2  * Copyright (C) 2015 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package com.example.android.rs.blasbenchmark;
     18 
     19 
     20 import android.app.Activity;
     21 import android.os.Bundle;
     22 import android.util.Log;
     23 
     24 import com.example.android.rs.blasbenchmark.BlasTestList.TestName;
     25 
     26 import android.test.ActivityInstrumentationTestCase2;
     27 import android.test.suitebuilder.annotation.MediumTest;
     28 
     29 /**
     30  * BLAS benchmark test.
     31  * To run the test, please use command
     32  *
     33  * adb shell am instrument -w com.example.android.rs.blasbenchmark/android.support.test.runner.AndroidJUnitRunner
     34  *
     35  */
     36 public class BlasTest extends ActivityInstrumentationTestCase2<BlasBenchmark> {
     37     private final String TAG = "BLAS Test";
     38     // Only run 1 iteration now to fit the MediumTest time requirement.
     39     // One iteration means running the tests continuous for 1s.
     40     private int mIteration = 1;
     41     private BlasBenchmark mActivity;
     42 
     43     public BlasTest() {
     44         super(BlasBenchmark.class);
     45     }
     46 
     47     // Initialize the parameter for ImageProcessingActivityJB.
     48     protected void prepareTest() {
     49         mActivity = getActivity();
     50         mActivity.prepareInstrumentationTest();
     51     }
     52 
     53     @Override
     54     public void setUp() throws Exception {
     55         super.setUp();
     56         prepareTest();
     57         setActivityInitialTouchMode(false);
     58    }
     59 
     60     @Override
     61     public void tearDown() throws Exception {
     62         super.tearDown();
     63     }
     64 
     65     class TestAction implements Runnable {
     66         TestName mTestName;
     67         float mResult;
     68         public TestAction(TestName testName) {
     69             mTestName = testName;
     70         }
     71         public void run() {
     72             mResult = mActivity.mProcessor.getInstrumentationResult(mTestName);
     73             Log.v(TAG, "Benchmark for test \"" + mTestName.toString() + "\" is: " + mResult);
     74             synchronized(this) {
     75                 this.notify();
     76             }
     77         }
     78         public float getBenchmark() {
     79             return mResult;
     80         }
     81     }
     82 
     83     // Set the benchmark thread to run on ui thread
     84     // Synchronized the thread such that the test will wait for the benchmark thread to finish
     85     public void runOnUiThread(Runnable action) {
     86         synchronized(action) {
     87             mActivity.runOnUiThread(action);
     88             try {
     89                 action.wait();
     90             } catch (InterruptedException e) {
     91                 Log.v(TAG, "waiting for action running on UI thread is interrupted: " +
     92                         e.toString());
     93             }
     94         }
     95     }
     96 
     97     public void runTest(TestAction ta, String testName) {
     98         float sum = 0;
     99         for (int i = 0; i < mIteration; i++) {
    100             runOnUiThread(ta);
    101             float bmValue = ta.getBenchmark();
    102             Log.v(TAG, "results for iteration " + i + " is " + bmValue);
    103             sum += bmValue;
    104         }
    105         float avgResult = sum/mIteration;
    106 
    107         // post result to INSTRUMENTATION_STATUS
    108         Bundle results = new Bundle();
    109         results.putFloat(testName + "_avg", avgResult);
    110         getInstrumentation().sendStatus(Activity.RESULT_OK, results);
    111     }
    112 
    113     // Test case 0: SGEMM Test Small
    114     @MediumTest
    115     public void testSGEMMSmall() {
    116         TestAction ta = new TestAction(TestName.SGEMM_SMALL);
    117         runTest(ta, TestName.SGEMM_SMALL.name());
    118     }
    119 
    120     // Test case 1: SGEMM Test Medium
    121     @MediumTest
    122     public void testSGEMMedium() {
    123         TestAction ta = new TestAction(TestName.SGEMM_MEDIUM);
    124         runTest(ta, TestName.SGEMM_MEDIUM.name());
    125     }
    126 
    127     // Test case 2: SGEMM Test Large
    128     @MediumTest
    129     public void testSGEMMLarge() {
    130         TestAction ta = new TestAction(TestName.SGEMM_LARGE);
    131         runTest(ta, TestName.SGEMM_LARGE.name());
    132     }
    133 
    134     // Test case 3: 8Bit GEMM Test Small
    135     @MediumTest
    136     public void testBNNMSmall() {
    137         TestAction ta = new TestAction(TestName.BNNM_SMALL);
    138         runTest(ta, TestName.BNNM_SMALL.name());
    139     }
    140 
    141     // Test case 4: 8Bit GEMM Test Medium
    142     @MediumTest
    143     public void testBNNMMMedium() {
    144         TestAction ta = new TestAction(TestName.BNNM_MEDIUM);
    145         runTest(ta, TestName.BNNM_MEDIUM.name());
    146     }
    147 
    148     // Test case 5: 8Bit GEMM Test Large
    149     @MediumTest
    150     public void testBNNMLarge() {
    151         TestAction ta = new TestAction(TestName.BNNM_LARGE);
    152         runTest(ta, TestName.BNNM_LARGE.name());
    153     }
    154 
    155     // Test case 6: SGEMM GoogLeNet Test
    156     @MediumTest
    157     public void testSGEMMGoogLeNet() {
    158         TestAction ta = new TestAction(TestName.SGEMM_GoogLeNet);
    159         runTest(ta, TestName.SGEMM_GoogLeNet.name());
    160     }
    161 
    162     // Test case 7: 8Bit GEMM GoogLeNet Test
    163     @MediumTest
    164     public void testBNNMGoogLeNet() {
    165         TestAction ta = new TestAction(TestName.BNNM_GoogLeNet);
    166         runTest(ta, TestName.BNNM_GoogLeNet.name());
    167     }
    168 
    169     // Test case 8: SGEMM GoogLeNet Test Padded
    170     @MediumTest
    171     public void testSGEMMGoogLeNetPadded() {
    172         TestAction ta = new TestAction(TestName.SGEMM_GoogLeNet_Padded);
    173         runTest(ta, TestName.SGEMM_GoogLeNet_Padded.name());
    174     }
    175 
    176     // Test case 9: 8Bit GEMM GoogLeNet Test Padded
    177     @MediumTest
    178     public void testBNNMGoogLeNetPadded() {
    179         TestAction ta = new TestAction(TestName.BNNM_GoogLeNet_Padded);
    180         runTest(ta, TestName.BNNM_GoogLeNet_Padded.name());
    181     }
    182 }
    183