Home | History | Annotate | Download | only in blasbenchmark
      1 /*
      2  * Copyright (C) 2015 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package com.example.android.rs.blasbenchmark;
     18 
     19 import android.renderscript.*;
     20 import android.util.Log;
     21 import java.util.Random;
     22 import java.lang.Math;
     23 
     24 public class BNNMTest extends TestBase {
     25 
     26     static {
     27         System.loadLibrary("gemmdata");
     28     }
     29 
     30     native void getData(byte[] a, byte[] b, byte[] c);
     31 
     32     ScriptIntrinsicBLAS mBLAS;
     33     private Allocation matA;
     34     private Allocation matB;
     35     private Allocation matC;
     36 
     37     private int m;
     38     private int n;
     39     private int k;
     40 
     41     private int a_offset;
     42     private int b_offset;
     43     private int c_offset;
     44     private int c_mult_int;
     45 
     46     private int mTestSize;
     47 
     48     BNNMTest(int testSize) {
     49         mTestSize = testSize;
     50     }
     51 
     52     public void createTest() {
     53         mBLAS = ScriptIntrinsicBLAS.create(mRS);
     54         setTest();
     55     }
     56 
     57     private void setTest() {
     58         switch (mTestSize) {
     59             case 1:
     60                 setTestSmall();
     61                 break;
     62             case 2:
     63                 setTestMedium();
     64                 break;
     65             case 3:
     66                 setTestLarge();
     67                 break;
     68             default:
     69                 break;
     70         }
     71     }
     72 
     73     // In Java, the eight-bit 'byte' type is signed, but the API for the 8-bit
     74     // matrix multiplication deals with unsigned bytes. This is a convenience
     75     // function that converts arrays of unsigned ints to their equivalent
     76     // representations as signed bytes. For example, the bit pattern 0xff is 255
     77     // as an unsigned value, but -127 as a Java signed byte. So if you pass in an
     78     // array of int[] {255} into this function, you'll get back byte[] {-127}.
     79     private byte[] unsignedToSignedByte(int[] input) {
     80         byte[] output = new byte[input.length];
     81         for (int i = 0; i < input.length; ++i) {
     82             output[i] = (byte)(input[i]);
     83         }
     84         return output;
     85     }
     86 
     87 
     88     private void addByteNoise(byte[] data, int count, float frequency, int maxDelta) {
     89         Random rand = new Random();
     90         for (int n = 0; n < count; ++n) {
     91             if (rand.nextFloat() < frequency) {
     92                 final int originalValue = data[n];
     93                 final float direction = rand.nextFloat();
     94                 int delta = (int)(Math.ceil(rand.nextFloat() * maxDelta));
     95                 if (direction < 0.5f) {
     96                     delta = -delta;
     97                 }
     98                 int newValue = (originalValue + delta);
     99                 if (newValue < -127) {
    100                     newValue = -127;
    101                 }
    102                 if (newValue > 127) {
    103                     newValue = 127;
    104                 }
    105                 data[n] = (byte)(newValue);
    106             }
    107         }
    108     }
    109 
    110     private boolean testWithTolerance(byte[] c_byte, byte[] c_byte_output) {
    111 
    112         // The testing procedure here is a bit complex, but the aim is to mimic the
    113         // requirements we've empirically found running deep neural networks in real
    114         // applications. We want to open the door to vendors using approximations that
    115         // produce slightly different results for optimization's sake, but keep the
    116         // precision loss within small enough bounds that we don't lose accuracy in
    117         // the final result.
    118         // After experimentation, we've found that we can tolerate around 5% of the
    119         // output bytes being different by 1. Any larger differences are not tolerable
    120         // and we can't get good results if the frequency of small differences is
    121         // higher than 5%. This test tries to measure those properties on an example
    122         // set of parameters that were captured from a real application.
    123         // For example, if you uncommented this function that adds random noise to the
    124         // results at a 3% specified frequency, the test should fail:
    125         // AddByteNoise(c_byte_output, c_count, 0.03f, 1);
    126 
    127         final boolean areSizesDifferent = (c_byte.length != c_byte_output.length);
    128         final int c_count = Math.min(c_byte.length, c_byte_output.length);
    129 
    130         int howManyDifferent = 0;
    131         boolean areAnyTooDifferent = false;
    132         for (int i = 0; i < c_count; i++) {
    133             byte expectedValue = c_byte[i];
    134             byte actualValue = c_byte_output[i];
    135             int delta = (expectedValue - actualValue);
    136             // First make sure that the difference is no more than one.
    137             if ((delta < -1) || (delta > 1)) {
    138                 areAnyTooDifferent = true;
    139             }
    140             // If there is a difference, increment the counter to track it.
    141             if (delta != 0) {
    142                 // Don't spam the logs if too many are different.
    143                 if (howManyDifferent < 50) {
    144                     android.util.Log.e("BNNM", "Mismatch at " + i +
    145                                        ": expected " + (expectedValue & 0xff) +
    146                                        ", got " + (actualValue & 0xff));
    147                 }
    148                 ++howManyDifferent;
    149             }
    150         }
    151         // We want no more than 2% of the values to show any differences, so work out
    152         // what that means in absolute numbers.
    153         final int percentThreshold = 2;
    154         final int differenceThreshold = Math.max((percentThreshold * c_count) / 100, 1);
    155         final boolean areTooManyDifferent = (howManyDifferent >= differenceThreshold);
    156 
    157         if (areAnyTooDifferent) {
    158             android.util.Log.e("BNNM", "Some outputs were too different.");
    159         }
    160 
    161         if (areTooManyDifferent) {
    162             android.util.Log.e("BNNM", "There were too many small differences." +
    163                                " We can tolerate " + percentThreshold + "% (" +
    164                                differenceThreshold + "), but there were " + howManyDifferent);
    165         }
    166 
    167         return !(areAnyTooDifferent || areTooManyDifferent);
    168     }
    169 
    170     // This test multiplies a couple of small 8-bit matrices, and compares the
    171     // results with hand-calculated expectations.
    172     public void setTestSmall() {
    173         // The A matrix is:
    174         // |   1 |   4 |
    175         // |   2 |   5 |
    176         // |   3 |   6 |
    177         byte[] a_byte = unsignedToSignedByte(new int[] {
    178                 1, 2, 3,
    179                 4, 5, 6,
    180             });
    181         final int a_rows = 3;
    182         final int a_cols = 2;
    183         a_offset = 0;
    184         // The B matrix is:
    185         // |  -1 |  -2 |  -3 |  -4 |
    186         // |  -5 |  -6 |  -7 |  -8 |
    187         // |  -9 | -10 | -11 | -12 |
    188         byte[] b_byte = unsignedToSignedByte(new int[] {
    189                 11, 7, 3,
    190                 10, 6, 2,
    191                 9, 5, 1,
    192                 8, 4, 0,
    193             });
    194         final int b_cols = 4;
    195         b_offset = 12;
    196         // EightBitGemm implements C = B.transposed() * A,
    197         // so we expect to get these results:
    198         // 1*-1 + 2*-5 + 3*-9 + 128 = 90
    199         // 1*-2 + 2*-6 + 3*-10 + 128 = 84
    200         // 1*-3 + 2*-7 + 3*-11 + 128 = 78
    201         // 1*-4 + 2*-8 + 3*-12 + 128 = 72
    202         // 4*-1 + 5*-5 + 6*-9 + 128 = 45
    203         // 4*-2 + 5*-6 + 6*-10 + 128 = 30
    204         // 4*-3 + 5*-7 + 6*-11 + 128 = 15
    205         // 4*-4 + 5*-8 + 6*-12 + 128 = 0
    206         // | 90 |  45 |
    207         // | 84 |  30 |
    208         // | 78 | 15 |
    209         // | 72 | 0 |
    210         c_offset = 128;
    211         final int c_shift = 21;
    212         c_mult_int = (1 << c_shift);
    213         byte[] expected_data = unsignedToSignedByte(new int[] {
    214                 90, 84, 78, 72,
    215                 45, 30, 15, 0,
    216             });
    217 
    218         m = a_cols;
    219         n = b_cols;
    220         k = a_rows;
    221 
    222         Type.Builder builder = new Type.Builder(mRS, Element.U8(mRS));
    223         Type a_type = builder.setX(k).setY(m).create();
    224         Type b_type = builder.setX(k).setY(n).create();
    225         Type c_type = builder.setX(n).setY(m).create();
    226 
    227         matA = Allocation.createTyped(mRS, a_type);
    228         matB = Allocation.createTyped(mRS, b_type);
    229         matC = Allocation.createTyped(mRS, c_type);
    230         matA.copyFrom(a_byte);
    231         matB.copyFrom(b_byte);
    232 
    233         //During setup, do a sample run to see if the result is correct.
    234         mBLAS.BNNM(matA, a_offset, matB, b_offset, matC, c_offset, c_mult_int);
    235         int c_count = (m * n);
    236         byte[] c_byte_output = new byte[c_count];
    237         matC.copyTo(c_byte_output);
    238         if (!testWithTolerance(expected_data, c_byte_output)) {
    239             Log.e(TAG, "Result is not correct!");
    240             throw new AssertionError("Result is not correct.");
    241         }
    242     }
    243 
    244 
    245     // This test multiplies another two medium 8-bit matrices, and compares the
    246     // results with the expected values. The data here is arbitrary.
    247     public void setTestMedium() {
    248         byte[] a_byte = unsignedToSignedByte(new int[] {
    249                 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
    250                 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1,
    251                 1, 23, 2, 22, 3, 21, 4, 20, 5, 19, 6, 18, 7, 17, 8, 16, 9, 15, 10, 14, 11, 13, 12,
    252                 23, 1, 22, 2, 21, 3, 20, 4, 19, 5, 18, 6, 17, 7, 16, 8, 15, 9, 14, 10, 13, 11, 12,
    253                 1, 1, 1, 1, 1, 1, 1, 1, 1, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
    254                 3, 1, 4, 1, 5, 8, 2, 3, 1, 14, 11, 15, 18, 12, 13, 11, 14, 11, 15, 18, 12, 13, 11,
    255                 8, 0, 5, 8, 1, 3, 7, 5, 7, 13, 10, 23, 13, 11, 17, 23, 12, 19, 17, 13, 14, 10, 19,
    256             });
    257         final int a_rows = 23;
    258         final int a_cols = 7;
    259         a_offset = 13;
    260         byte[] b_byte = unsignedToSignedByte(new int[] {
    261                 0, 2, 4, 6, 8, 10, 1, 3, 5, 7, 9, 11, 0, 2, 4, 6, 8, 10, 1, 3, 5, 7, 9,
    262                 0, 20, 40, 60, 80, 10, 11, 13, 15, 17, 19, 21, 10, 12, 14, 6, 8, 10, 1, 3, 5, 7, 9,
    263                 1, 21, 41, 61, 81, 11, 12, 14, 16, 18, 20, 22, 11, 13, 15, 7, 9, 11, 2, 4, 6, 8, 9,
    264                 0, 19, 39, 59, 79, 9, 10, 12, 14, 16, 18, 20, 9, 11, 13, 5, 7, 9, 0, 2, 4, 6, 8,
    265                 2, 22, 42, 62, 82, 12, 13, 15, 17, 19, 21, 23, 12, 14, 16, 8, 9, 12, 3, 5, 7, 9, 9,
    266                 0, 18, 38, 58, 78, 8, 9, 11, 13, 15, 17, 19, 8, 10, 12, 4, 6, 8, 0, 1, 3, 5, 7,
    267                 3, 23, 43, 63, 83, 13, 14, 16, 18, 20, 22, 24, 13, 15, 17, 9, 9, 13, 4, 6, 8, 9, 9,
    268                 0, 17, 37, 57, 77, 7, 8, 10, 12, 14, 16, 18, 7, 9, 11, 3, 5, 7, 0, 0, 2, 4, 6,
    269                 10, 20, 30, 40, 50, 1, 2, 3, 4, 5, 11, 12, 13, 14, 15, 21, 22, 23, 24, 25, 1, 2, 3,
    270             });
    271         final int b_cols = 9;
    272         b_offset = 23;
    273         c_offset = 2121;
    274         final int c_shift = 21;
    275         c_mult_int = 132359;
    276         byte[] expected_data = unsignedToSignedByte(new int[] {
    277                 167, 53, 51, 54, 49, 55, 46,
    278                 56, 116, 153, 232, 232, 234, 231,
    279                 236, 232, 237, 174, 168, 131, 130,
    280                 132, 129, 133, 128, 133, 134, 151,
    281                 154, 152, 156, 151, 158, 150, 160,
    282                 156, 255, 113, 106, 120, 98, 127,
    283                 91, 134, 178, 231, 102, 97, 107,
    284                 92, 111, 87, 116, 164, 187, 76,
    285                 73, 78, 70, 81, 67, 83, 139,
    286             });
    287 
    288         m = a_cols;
    289         n = b_cols;
    290         k = a_rows;
    291 
    292         Type.Builder builder = new Type.Builder(mRS, Element.U8(mRS));
    293         Type a_type = builder.setX(k).setY(m).create();
    294         Type b_type = builder.setX(k).setY(n).create();
    295         Type c_type = builder.setX(n).setY(m).create();
    296 
    297         matA = Allocation.createTyped(mRS, a_type);
    298         matB = Allocation.createTyped(mRS, b_type);
    299         matC = Allocation.createTyped(mRS, c_type);
    300 
    301         matA.copyFrom(a_byte);
    302         matB.copyFrom(b_byte);
    303 
    304         //During setup, do a sample run to see if the result is correct.
    305         mBLAS.BNNM(matA, a_offset, matB, b_offset, matC, c_offset, c_mult_int);
    306         int c_count = (m * n);
    307         byte[] c_byte_output = new byte[c_count];
    308         matC.copyTo(c_byte_output);
    309         if (!testWithTolerance(expected_data, c_byte_output)) {
    310             Log.e(TAG, "Result is not correct!");
    311             throw new AssertionError("Result is not correct.");
    312         }
    313     }
    314 
    315 
    316 
    317     // This test takes a large set of real data captured from a convolutional
    318     // neural network solving a computer vision problem, and runs it through the
    319     // eight-bit matrix multiply. We test the results to make sure they're close
    320     // enough to be usable.
    321     public void setTestLarge() {
    322 
    323         m = 256;
    324         n = 192;
    325         k = 1152;
    326         a_offset = 0;
    327         b_offset = 84;
    328         c_mult_int = 3401;
    329         c_offset = 74980;
    330 
    331         int a_count = (m * k);
    332         int b_count = (n * k);
    333         int c_count = (m * n);
    334 
    335         byte[] a_byte = new byte[a_count];
    336         byte[] b_byte = new byte[b_count];
    337         byte[] c_byte = new byte[c_count];
    338 
    339         getData(a_byte, b_byte, c_byte);
    340 
    341         Type.Builder builder = new Type.Builder(mRS, Element.U8(mRS));
    342         Type a_type = builder.setX(k).setY(m).create();
    343         Type b_type = builder.setX(k).setY(n).create();
    344         Type c_type = builder.setX(n).setY(m).create();
    345 
    346         matA = Allocation.createTyped(mRS, a_type);
    347         matB = Allocation.createTyped(mRS, b_type);
    348         matC = Allocation.createTyped(mRS, c_type);
    349 
    350         matA.copyFrom(a_byte);
    351         matB.copyFrom(b_byte);
    352 
    353         //During setup, do a sample run to see if the result is correct.
    354         mBLAS.BNNM(matA, a_offset, matB, b_offset, matC, c_offset, c_mult_int);
    355         byte[] c_byte_output = new byte[c_count];
    356         matC.copyTo(c_byte_output);
    357         if (!testWithTolerance(c_byte, c_byte_output)) {
    358             Log.e(TAG, "Result is not correct!");
    359             throw new AssertionError("Result is not correct.");
    360         }
    361     }
    362 
    363     public void runTest() {
    364         mBLAS.BNNM(matA, a_offset, matB, b_offset, matC, c_offset, c_mult_int);
    365     }
    366 
    367     public String getTestInfo() {
    368         return "8Bit GEMM Test: m=" + m + ", n=" + n + ", k=" + k;
    369     }
    370 }
    371