1 /* 2 * Copyright (C) 2015 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.example.android.rs.blasbenchmark; 18 19 import android.renderscript.*; 20 import android.util.Log; 21 import java.util.Random; 22 import java.lang.Math; 23 24 public class BNNMTest extends TestBase { 25 26 static { 27 System.loadLibrary("gemmdata"); 28 } 29 30 native void getData(byte[] a, byte[] b, byte[] c); 31 32 ScriptIntrinsicBLAS mBLAS; 33 private Allocation matA; 34 private Allocation matB; 35 private Allocation matC; 36 37 private int m; 38 private int n; 39 private int k; 40 41 private int a_offset; 42 private int b_offset; 43 private int c_offset; 44 private int c_mult_int; 45 46 private int mTestSize; 47 48 BNNMTest(int testSize) { 49 mTestSize = testSize; 50 } 51 52 public void createTest() { 53 mBLAS = ScriptIntrinsicBLAS.create(mRS); 54 setTest(); 55 } 56 57 private void setTest() { 58 switch (mTestSize) { 59 case 1: 60 setTestSmall(); 61 break; 62 case 2: 63 setTestMedium(); 64 break; 65 case 3: 66 setTestLarge(); 67 break; 68 default: 69 break; 70 } 71 } 72 73 // In Java, the eight-bit 'byte' type is signed, but the API for the 8-bit 74 // matrix multiplication deals with unsigned bytes. This is a convenience 75 // function that converts arrays of unsigned ints to their equivalent 76 // representations as signed bytes. For example, the bit pattern 0xff is 255 77 // as an unsigned value, but -127 as a Java signed byte. So if you pass in an 78 // array of int[] {255} into this function, you'll get back byte[] {-127}. 79 private byte[] unsignedToSignedByte(int[] input) { 80 byte[] output = new byte[input.length]; 81 for (int i = 0; i < input.length; ++i) { 82 output[i] = (byte)(input[i]); 83 } 84 return output; 85 } 86 87 88 private void addByteNoise(byte[] data, int count, float frequency, int maxDelta) { 89 Random rand = new Random(); 90 for (int n = 0; n < count; ++n) { 91 if (rand.nextFloat() < frequency) { 92 final int originalValue = data[n]; 93 final float direction = rand.nextFloat(); 94 int delta = (int)(Math.ceil(rand.nextFloat() * maxDelta)); 95 if (direction < 0.5f) { 96 delta = -delta; 97 } 98 int newValue = (originalValue + delta); 99 if (newValue < -127) { 100 newValue = -127; 101 } 102 if (newValue > 127) { 103 newValue = 127; 104 } 105 data[n] = (byte)(newValue); 106 } 107 } 108 } 109 110 private boolean testWithTolerance(byte[] c_byte, byte[] c_byte_output) { 111 112 // The testing procedure here is a bit complex, but the aim is to mimic the 113 // requirements we've empirically found running deep neural networks in real 114 // applications. We want to open the door to vendors using approximations that 115 // produce slightly different results for optimization's sake, but keep the 116 // precision loss within small enough bounds that we don't lose accuracy in 117 // the final result. 118 // After experimentation, we've found that we can tolerate around 5% of the 119 // output bytes being different by 1. Any larger differences are not tolerable 120 // and we can't get good results if the frequency of small differences is 121 // higher than 5%. This test tries to measure those properties on an example 122 // set of parameters that were captured from a real application. 123 // For example, if you uncommented this function that adds random noise to the 124 // results at a 3% specified frequency, the test should fail: 125 // AddByteNoise(c_byte_output, c_count, 0.03f, 1); 126 127 final boolean areSizesDifferent = (c_byte.length != c_byte_output.length); 128 final int c_count = Math.min(c_byte.length, c_byte_output.length); 129 130 int howManyDifferent = 0; 131 boolean areAnyTooDifferent = false; 132 for (int i = 0; i < c_count; i++) { 133 byte expectedValue = c_byte[i]; 134 byte actualValue = c_byte_output[i]; 135 int delta = (expectedValue - actualValue); 136 // First make sure that the difference is no more than one. 137 if ((delta < -1) || (delta > 1)) { 138 areAnyTooDifferent = true; 139 } 140 // If there is a difference, increment the counter to track it. 141 if (delta != 0) { 142 // Don't spam the logs if too many are different. 143 if (howManyDifferent < 50) { 144 android.util.Log.e("BNNM", "Mismatch at " + i + 145 ": expected " + (expectedValue & 0xff) + 146 ", got " + (actualValue & 0xff)); 147 } 148 ++howManyDifferent; 149 } 150 } 151 // We want no more than 2% of the values to show any differences, so work out 152 // what that means in absolute numbers. 153 final int percentThreshold = 2; 154 final int differenceThreshold = Math.max((percentThreshold * c_count) / 100, 1); 155 final boolean areTooManyDifferent = (howManyDifferent >= differenceThreshold); 156 157 if (areAnyTooDifferent) { 158 android.util.Log.e("BNNM", "Some outputs were too different."); 159 } 160 161 if (areTooManyDifferent) { 162 android.util.Log.e("BNNM", "There were too many small differences." + 163 " We can tolerate " + percentThreshold + "% (" + 164 differenceThreshold + "), but there were " + howManyDifferent); 165 } 166 167 return !(areAnyTooDifferent || areTooManyDifferent); 168 } 169 170 // This test multiplies a couple of small 8-bit matrices, and compares the 171 // results with hand-calculated expectations. 172 public void setTestSmall() { 173 // The A matrix is: 174 // | 1 | 4 | 175 // | 2 | 5 | 176 // | 3 | 6 | 177 byte[] a_byte = unsignedToSignedByte(new int[] { 178 1, 2, 3, 179 4, 5, 6, 180 }); 181 final int a_rows = 3; 182 final int a_cols = 2; 183 a_offset = 0; 184 // The B matrix is: 185 // | -1 | -2 | -3 | -4 | 186 // | -5 | -6 | -7 | -8 | 187 // | -9 | -10 | -11 | -12 | 188 byte[] b_byte = unsignedToSignedByte(new int[] { 189 11, 7, 3, 190 10, 6, 2, 191 9, 5, 1, 192 8, 4, 0, 193 }); 194 final int b_cols = 4; 195 b_offset = 12; 196 // EightBitGemm implements C = B.transposed() * A, 197 // so we expect to get these results: 198 // 1*-1 + 2*-5 + 3*-9 + 128 = 90 199 // 1*-2 + 2*-6 + 3*-10 + 128 = 84 200 // 1*-3 + 2*-7 + 3*-11 + 128 = 78 201 // 1*-4 + 2*-8 + 3*-12 + 128 = 72 202 // 4*-1 + 5*-5 + 6*-9 + 128 = 45 203 // 4*-2 + 5*-6 + 6*-10 + 128 = 30 204 // 4*-3 + 5*-7 + 6*-11 + 128 = 15 205 // 4*-4 + 5*-8 + 6*-12 + 128 = 0 206 // | 90 | 45 | 207 // | 84 | 30 | 208 // | 78 | 15 | 209 // | 72 | 0 | 210 c_offset = 128; 211 final int c_shift = 21; 212 c_mult_int = (1 << c_shift); 213 byte[] expected_data = unsignedToSignedByte(new int[] { 214 90, 84, 78, 72, 215 45, 30, 15, 0, 216 }); 217 218 m = a_cols; 219 n = b_cols; 220 k = a_rows; 221 222 Type.Builder builder = new Type.Builder(mRS, Element.U8(mRS)); 223 Type a_type = builder.setX(k).setY(m).create(); 224 Type b_type = builder.setX(k).setY(n).create(); 225 Type c_type = builder.setX(n).setY(m).create(); 226 227 matA = Allocation.createTyped(mRS, a_type); 228 matB = Allocation.createTyped(mRS, b_type); 229 matC = Allocation.createTyped(mRS, c_type); 230 matA.copyFrom(a_byte); 231 matB.copyFrom(b_byte); 232 233 //During setup, do a sample run to see if the result is correct. 234 mBLAS.BNNM(matA, a_offset, matB, b_offset, matC, c_offset, c_mult_int); 235 int c_count = (m * n); 236 byte[] c_byte_output = new byte[c_count]; 237 matC.copyTo(c_byte_output); 238 if (!testWithTolerance(expected_data, c_byte_output)) { 239 Log.e(TAG, "Result is not correct!"); 240 throw new AssertionError("Result is not correct."); 241 } 242 } 243 244 245 // This test multiplies another two medium 8-bit matrices, and compares the 246 // results with the expected values. The data here is arbitrary. 247 public void setTestMedium() { 248 byte[] a_byte = unsignedToSignedByte(new int[] { 249 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 250 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 251 1, 23, 2, 22, 3, 21, 4, 20, 5, 19, 6, 18, 7, 17, 8, 16, 9, 15, 10, 14, 11, 13, 12, 252 23, 1, 22, 2, 21, 3, 20, 4, 19, 5, 18, 6, 17, 7, 16, 8, 15, 9, 14, 10, 13, 11, 12, 253 1, 1, 1, 1, 1, 1, 1, 1, 1, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 254 3, 1, 4, 1, 5, 8, 2, 3, 1, 14, 11, 15, 18, 12, 13, 11, 14, 11, 15, 18, 12, 13, 11, 255 8, 0, 5, 8, 1, 3, 7, 5, 7, 13, 10, 23, 13, 11, 17, 23, 12, 19, 17, 13, 14, 10, 19, 256 }); 257 final int a_rows = 23; 258 final int a_cols = 7; 259 a_offset = 13; 260 byte[] b_byte = unsignedToSignedByte(new int[] { 261 0, 2, 4, 6, 8, 10, 1, 3, 5, 7, 9, 11, 0, 2, 4, 6, 8, 10, 1, 3, 5, 7, 9, 262 0, 20, 40, 60, 80, 10, 11, 13, 15, 17, 19, 21, 10, 12, 14, 6, 8, 10, 1, 3, 5, 7, 9, 263 1, 21, 41, 61, 81, 11, 12, 14, 16, 18, 20, 22, 11, 13, 15, 7, 9, 11, 2, 4, 6, 8, 9, 264 0, 19, 39, 59, 79, 9, 10, 12, 14, 16, 18, 20, 9, 11, 13, 5, 7, 9, 0, 2, 4, 6, 8, 265 2, 22, 42, 62, 82, 12, 13, 15, 17, 19, 21, 23, 12, 14, 16, 8, 9, 12, 3, 5, 7, 9, 9, 266 0, 18, 38, 58, 78, 8, 9, 11, 13, 15, 17, 19, 8, 10, 12, 4, 6, 8, 0, 1, 3, 5, 7, 267 3, 23, 43, 63, 83, 13, 14, 16, 18, 20, 22, 24, 13, 15, 17, 9, 9, 13, 4, 6, 8, 9, 9, 268 0, 17, 37, 57, 77, 7, 8, 10, 12, 14, 16, 18, 7, 9, 11, 3, 5, 7, 0, 0, 2, 4, 6, 269 10, 20, 30, 40, 50, 1, 2, 3, 4, 5, 11, 12, 13, 14, 15, 21, 22, 23, 24, 25, 1, 2, 3, 270 }); 271 final int b_cols = 9; 272 b_offset = 23; 273 c_offset = 2121; 274 final int c_shift = 21; 275 c_mult_int = 132359; 276 byte[] expected_data = unsignedToSignedByte(new int[] { 277 167, 53, 51, 54, 49, 55, 46, 278 56, 116, 153, 232, 232, 234, 231, 279 236, 232, 237, 174, 168, 131, 130, 280 132, 129, 133, 128, 133, 134, 151, 281 154, 152, 156, 151, 158, 150, 160, 282 156, 255, 113, 106, 120, 98, 127, 283 91, 134, 178, 231, 102, 97, 107, 284 92, 111, 87, 116, 164, 187, 76, 285 73, 78, 70, 81, 67, 83, 139, 286 }); 287 288 m = a_cols; 289 n = b_cols; 290 k = a_rows; 291 292 Type.Builder builder = new Type.Builder(mRS, Element.U8(mRS)); 293 Type a_type = builder.setX(k).setY(m).create(); 294 Type b_type = builder.setX(k).setY(n).create(); 295 Type c_type = builder.setX(n).setY(m).create(); 296 297 matA = Allocation.createTyped(mRS, a_type); 298 matB = Allocation.createTyped(mRS, b_type); 299 matC = Allocation.createTyped(mRS, c_type); 300 301 matA.copyFrom(a_byte); 302 matB.copyFrom(b_byte); 303 304 //During setup, do a sample run to see if the result is correct. 305 mBLAS.BNNM(matA, a_offset, matB, b_offset, matC, c_offset, c_mult_int); 306 int c_count = (m * n); 307 byte[] c_byte_output = new byte[c_count]; 308 matC.copyTo(c_byte_output); 309 if (!testWithTolerance(expected_data, c_byte_output)) { 310 Log.e(TAG, "Result is not correct!"); 311 throw new AssertionError("Result is not correct."); 312 } 313 } 314 315 316 317 // This test takes a large set of real data captured from a convolutional 318 // neural network solving a computer vision problem, and runs it through the 319 // eight-bit matrix multiply. We test the results to make sure they're close 320 // enough to be usable. 321 public void setTestLarge() { 322 323 m = 256; 324 n = 192; 325 k = 1152; 326 a_offset = 0; 327 b_offset = 84; 328 c_mult_int = 3401; 329 c_offset = 74980; 330 331 int a_count = (m * k); 332 int b_count = (n * k); 333 int c_count = (m * n); 334 335 byte[] a_byte = new byte[a_count]; 336 byte[] b_byte = new byte[b_count]; 337 byte[] c_byte = new byte[c_count]; 338 339 getData(a_byte, b_byte, c_byte); 340 341 Type.Builder builder = new Type.Builder(mRS, Element.U8(mRS)); 342 Type a_type = builder.setX(k).setY(m).create(); 343 Type b_type = builder.setX(k).setY(n).create(); 344 Type c_type = builder.setX(n).setY(m).create(); 345 346 matA = Allocation.createTyped(mRS, a_type); 347 matB = Allocation.createTyped(mRS, b_type); 348 matC = Allocation.createTyped(mRS, c_type); 349 350 matA.copyFrom(a_byte); 351 matB.copyFrom(b_byte); 352 353 //During setup, do a sample run to see if the result is correct. 354 mBLAS.BNNM(matA, a_offset, matB, b_offset, matC, c_offset, c_mult_int); 355 byte[] c_byte_output = new byte[c_count]; 356 matC.copyTo(c_byte_output); 357 if (!testWithTolerance(c_byte, c_byte_output)) { 358 Log.e(TAG, "Result is not correct!"); 359 throw new AssertionError("Result is not correct."); 360 } 361 } 362 363 public void runTest() { 364 mBLAS.BNNM(matA, a_offset, matB, b_offset, matC, c_offset, c_mult_int); 365 } 366 367 public String getTestInfo() { 368 return "8Bit GEMM Test: m=" + m + ", n=" + n + ", k=" + k; 369 } 370 } 371