Home | History | Annotate | Download | only in blasbenchmark
      1 /*
      2  * Copyright (C) 2015 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package com.example.android.rs.blasbenchmark;
     18 
     19 import android.renderscript.*;
     20 import android.util.Log;
     21 import java.util.Random;
     22 import java.lang.Math;
     23 
     24 public class SGEMMTest extends TestBase {
     25 
     26     static {
     27         System.loadLibrary("gemmdata");
     28     }
     29 
     30     native void getData(byte[] a, byte[] b, byte[] c);
     31 
     32     ScriptIntrinsicBLAS mBLAS;
     33     private Allocation matA;
     34     private Allocation matB;
     35     private Allocation matC;
     36 
     37     private int m;
     38     private int n;
     39     private int k;
     40 
     41     private int a_offset;
     42     private int b_offset;
     43     private int mTestSize;
     44     private final float allowedError = 0.000001f;
     45 
     46     SGEMMTest(int testSize) {
     47         mTestSize = testSize;
     48     }
     49 
     50     public void createTest() {
     51         mBLAS = ScriptIntrinsicBLAS.create(mRS);
     52         setTest();
     53     }
     54 
     55     private void setTest() {
     56         switch (mTestSize) {
     57             case 1:
     58                 setTestSmall();
     59                 break;
     60             case 2:
     61                 setTestMedium();
     62                 break;
     63             case 3:
     64                 setTestLarge();
     65                 break;
     66             default:
     67                 break;
     68         }
     69     }
     70 
     71     // Calculate the square of the L2 norm of a matrix.
     72     private float calcL2Norm(float[] input) {
     73         float l2Norm = 0.f;
     74         for (int i = 0; i < input.length; ++i) {
     75             l2Norm += input[i] * input[i];
     76         }
     77         return l2Norm;
     78     }
     79 
     80     // Test whether the error of each element is samller the allowed error range.
     81     private boolean testWithTolerance(float[] out, float[] ref) {
     82         float l2NormOut = calcL2Norm(out);
     83         float l2NormRef = calcL2Norm(ref);
     84         float tolerance = allowedError * (l2NormOut < l2NormRef ? l2NormOut : l2NormRef);
     85         tolerance /= m * n;
     86         for (int i = 0; i < out.length; ++i) {
     87             float err = out[i] - ref[i];
     88             float absErr = err * err;
     89             if (absErr > tolerance) {
     90                 return false;
     91             }
     92         }
     93         return true;
     94     }
     95 
     96     // Transform byte data into float, given a offset.
     97     private float[] byteToFloat(byte[] input, int offset) {
     98         float[] output = new float[input.length];
     99         for (int i = 0; i < input.length; ++i) {
    100             output[i] = (float)(input[i] - offset);
    101         }
    102         return output;
    103     }
    104 
    105     // Calculate the reference result for C = A*B
    106     private float[] getGEMMResult(int m, int n, int k, float[] a_float, float[] b_float) {
    107         float[] c_float = new float[m * n];
    108         for (int j = 0; j < n; j++) {
    109             for (int i = 0; i < m; i++) {
    110                 float total = 0.f;
    111                 for (int l = 0; l < k; l++) {
    112                     int a_index = ((i * k) + l);
    113                     int b_index = ((l * n) + j);
    114                     float mult = a_float[a_index] * b_float[b_index];
    115                     total += mult;
    116                 }
    117                 int c_index = ((i * n) + j);
    118                 c_float[c_index] = total;
    119             }
    120         }
    121         return c_float;
    122     }
    123 
    124     // This test multiplies a couple of small float matrices, and compares the
    125     // results with java-calculated expectations. The data here is arbitrary.
    126     public void setTestSmall() {
    127         m = 2;
    128         n = 4;
    129         k = 3;
    130         a_offset = 0;
    131         b_offset = 12;
    132 
    133         float[] a_float = byteToFloat(new byte[] {
    134                 1, 2, 3,
    135                 4, 5, 6,
    136             }, a_offset);
    137 
    138         float[] b_float = byteToFloat(new byte[] {
    139                 11, 7, 3,
    140                 10, 6, 2,
    141                 9, 5, 1,
    142                 8, 4, 0,
    143             }, b_offset);
    144 
    145         Type.Builder builder = new Type.Builder(mRS, Element.F32(mRS));
    146         Type a_type = builder.setX(k).setY(m).create();
    147         Type b_type = builder.setX(n).setY(k).create();
    148         Type c_type = builder.setX(n).setY(m).create();
    149 
    150         matA = Allocation.createTyped(mRS, a_type);
    151         matB = Allocation.createTyped(mRS, b_type);
    152         matC = Allocation.createTyped(mRS, c_type);
    153 
    154         matA.copyFrom(a_float);
    155         matB.copyFrom(b_float);
    156 
    157         //During setup, do a sample run to see if the result is correct.
    158         mBLAS.SGEMM(ScriptIntrinsicBLAS.NO_TRANSPOSE, ScriptIntrinsicBLAS.NO_TRANSPOSE,
    159                     1.0f, matA, matB, 0.f, matC);
    160         float[] c_float_ref = getGEMMResult(m, n, k, a_float, b_float);
    161         float[] c_float_out = new float[m * n];
    162         matC.copyTo(c_float_out);
    163         if (!testWithTolerance(c_float_ref, c_float_out)) {
    164             Log.e(TAG, "Result is not correct!");
    165             throw new AssertionError("Result is not correct.");
    166         }
    167     }
    168 
    169     // This test multiplies another two medium matrices, and compares the
    170     // results with the expected values. The data here is arbitrary.
    171     public void setTestMedium() {
    172         m = 7;
    173         n = 9;
    174         k = 23;
    175         a_offset = 13;
    176         b_offset = 23;
    177 
    178         float[] a_float = byteToFloat(new byte[] {
    179                 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
    180                 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1,
    181                 1, 23, 2, 22, 3, 21, 4, 20, 5, 19, 6, 18, 7, 17, 8, 16, 9, 15, 10, 14, 11, 13, 12,
    182                 23, 1, 22, 2, 21, 3, 20, 4, 19, 5, 18, 6, 17, 7, 16, 8, 15, 9, 14, 10, 13, 11, 12,
    183                 1, 1, 1, 1, 1, 1, 1, 1, 1, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
    184                 3, 1, 4, 1, 5, 8, 2, 3, 1, 14, 11, 15, 18, 12, 13, 11, 14, 11, 15, 18, 12, 13, 11,
    185                 8, 0, 5, 8, 1, 3, 7, 5, 7, 13, 10, 23, 13, 11, 17, 23, 12, 19, 17, 13, 14, 10, 19,
    186             }, a_offset);
    187 
    188         float[] b_float = byteToFloat(new byte[] {
    189                 0, 2, 4, 6, 8, 10, 1, 3, 5, 7, 9, 11, 0, 2, 4, 6, 8, 10, 1, 3, 5, 7, 9,
    190                 0, 20, 40, 60, 80, 10, 11, 13, 15, 17, 19, 21, 10, 12, 14, 6, 8, 10, 1, 3, 5, 7, 9,
    191                 1, 21, 41, 61, 81, 11, 12, 14, 16, 18, 20, 22, 11, 13, 15, 7, 9, 11, 2, 4, 6, 8, 9,
    192                 0, 19, 39, 59, 79, 9, 10, 12, 14, 16, 18, 20, 9, 11, 13, 5, 7, 9, 0, 2, 4, 6, 8,
    193                 2, 22, 42, 62, 82, 12, 13, 15, 17, 19, 21, 23, 12, 14, 16, 8, 9, 12, 3, 5, 7, 9, 9,
    194                 0, 18, 38, 58, 78, 8, 9, 11, 13, 15, 17, 19, 8, 10, 12, 4, 6, 8, 0, 1, 3, 5, 7,
    195                 3, 23, 43, 63, 83, 13, 14, 16, 18, 20, 22, 24, 13, 15, 17, 9, 9, 13, 4, 6, 8, 9, 9,
    196                 0, 17, 37, 57, 77, 7, 8, 10, 12, 14, 16, 18, 7, 9, 11, 3, 5, 7, 0, 0, 2, 4, 6,
    197                 10, 20, 30, 40, 50, 1, 2, 3, 4, 5, 11, 12, 13, 14, 15, 21, 22, 23, 24, 25, 1, 2, 3,
    198             }, b_offset);
    199 
    200         Type.Builder builder = new Type.Builder(mRS, Element.F32(mRS));
    201         Type a_type = builder.setX(k).setY(m).create();
    202         Type b_type = builder.setX(n).setY(k).create();
    203         Type c_type = builder.setX(n).setY(m).create();
    204 
    205         matA = Allocation.createTyped(mRS, a_type);
    206         matB = Allocation.createTyped(mRS, b_type);
    207         matC = Allocation.createTyped(mRS, c_type);
    208 
    209         matA.copyFrom(a_float);
    210         matB.copyFrom(b_float);
    211 
    212         //During setup, do a sample run to see if the result is correct.
    213         mBLAS.SGEMM(ScriptIntrinsicBLAS.NO_TRANSPOSE, ScriptIntrinsicBLAS.NO_TRANSPOSE,
    214                     1.0f, matA, matB, 0.f, matC);
    215         float[] c_float_ref = getGEMMResult(m, n, k, a_float, b_float);
    216         float[] c_float_out = new float[m * n];
    217         matC.copyTo(c_float_out);
    218         if (!testWithTolerance(c_float_ref, c_float_out)) {
    219             Log.e(TAG, "Result is not correct!");
    220             throw new AssertionError("Result is not correct.");
    221         }
    222     }
    223 
    224 
    225     // This test takes a large set of real data captured from a convolutional
    226     // neural network solving a computer vision problem, and runs it through SGEMM.
    227     public void setTestLarge() {
    228 
    229         m = 256;
    230         n = 192;
    231         k = 1152;
    232         a_offset = 0;
    233         b_offset = 84;
    234 
    235         int a_count = (m * k);
    236         int b_count = (n * k);
    237         int c_count = (m * n);
    238 
    239         byte[] a_byte = new byte[a_count];
    240         byte[] b_byte = new byte[b_count];
    241         byte[] c_byte = new byte[c_count];
    242 
    243         getData(a_byte, b_byte, c_byte);
    244 
    245         float[] a_float = byteToFloat(a_byte, a_offset);
    246         float[] b_float = byteToFloat(b_byte, b_offset);
    247 
    248         Type.Builder builder = new Type.Builder(mRS, Element.F32(mRS));
    249         Type a_type = builder.setX(k).setY(m).create();
    250         Type b_type = builder.setX(n).setY(k).create();
    251         Type c_type = builder.setX(n).setY(m).create();
    252 
    253         matA = Allocation.createTyped(mRS, a_type);
    254         matB = Allocation.createTyped(mRS, b_type);
    255         matC = Allocation.createTyped(mRS, c_type);
    256 
    257         matA.copyFrom(a_float);
    258         matB.copyFrom(b_float);
    259 
    260         //During setup, do a sample run to see if the result is correct.
    261         mBLAS.SGEMM(ScriptIntrinsicBLAS.NO_TRANSPOSE, ScriptIntrinsicBLAS.NO_TRANSPOSE,
    262                     1.0f, matA, matB, 0.f, matC);
    263         float[] c_float_ref = getGEMMResult(m, n, k, a_float, b_float);
    264         float[] c_float_out = new float[c_count];
    265         matC.copyTo(c_float_out);
    266         if (!testWithTolerance(c_float_ref, c_float_out)) {
    267             Log.e(TAG, "Result is not correct!");
    268             throw new AssertionError("Result is not correct.");
    269         }
    270     }
    271 
    272     public void runTest() {
    273         mBLAS.SGEMM(ScriptIntrinsicBLAS.NO_TRANSPOSE, ScriptIntrinsicBLAS.NO_TRANSPOSE,
    274                     1.0f, matA, matB, 0.f, matC);
    275     }
    276 
    277     public String getTestInfo() {
    278         return "SGEMM Test: m=" + m + ", n=" + n + ", k=" + k;
    279     }
    280 }
    281