1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow.lite; 17 18 import static com.google.common.truth.Truth.assertThat; 19 import static org.junit.Assert.fail; 20 21 import java.io.File; 22 import java.nio.MappedByteBuffer; 23 import java.nio.channels.FileChannel; 24 import java.nio.file.Files; 25 import java.nio.file.Path; 26 import java.nio.file.StandardOpenOption; 27 import java.util.EnumSet; 28 import java.util.HashMap; 29 import java.util.Map; 30 import org.junit.Test; 31 import org.junit.runner.RunWith; 32 import org.junit.runners.JUnit4; 33 34 /** Unit tests for {@link org.tensorflow.lite.Interpreter}. */ 35 @RunWith(JUnit4.class) 36 public final class InterpreterTest { 37 38 private static final File MODEL_FILE = 39 new File("tensorflow/contrib/lite/java/src/testdata/add.bin"); 40 41 private static final File MOBILENET_MODEL_FILE = 42 new File("tensorflow/contrib/lite/java/src/testdata/mobilenet.tflite.bin"); 43 44 @Test 45 public void testInterpreter() throws Exception { 46 Interpreter interpreter = new Interpreter(MODEL_FILE); 47 assertThat(interpreter).isNotNull(); 48 interpreter.close(); 49 } 50 51 @Test 52 public void testRunWithMappedByteBufferModel() throws Exception { 53 Path path = MODEL_FILE.toPath(); 54 FileChannel fileChannel = 55 (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ)); 56 MappedByteBuffer mappedByteBuffer = 57 fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size()); 58 Interpreter interpreter = new Interpreter(mappedByteBuffer); 59 float[] oneD = {1.23f, 6.54f, 7.81f}; 60 float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; 61 float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; 62 float[][][][] fourD = {threeD, threeD}; 63 float[][][][] parsedOutputs = new float[2][8][8][3]; 64 interpreter.run(fourD, parsedOutputs); 65 float[] outputOneD = parsedOutputs[0][0][0]; 66 float[] expected = {3.69f, 19.62f, 23.43f}; 67 assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); 68 interpreter.close(); 69 fileChannel.close(); 70 } 71 72 @Test 73 public void testRun() { 74 Interpreter interpreter = new Interpreter(MODEL_FILE); 75 Float[] oneD = {1.23f, 6.54f, 7.81f}; 76 Float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; 77 Float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; 78 Float[][][][] fourD = {threeD, threeD}; 79 Float[][][][] parsedOutputs = new Float[2][8][8][3]; 80 try { 81 interpreter.run(fourD, parsedOutputs); 82 fail(); 83 } catch (IllegalArgumentException e) { 84 assertThat(e).hasMessageThat().contains("cannot resolve DataType of [[[[Ljava.lang.Float;"); 85 } 86 interpreter.close(); 87 } 88 89 @Test 90 public void testRunWithBoxedInputs() { 91 Interpreter interpreter = new Interpreter(MODEL_FILE); 92 float[] oneD = {1.23f, 6.54f, 7.81f}; 93 float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; 94 float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; 95 float[][][][] fourD = {threeD, threeD}; 96 float[][][][] parsedOutputs = new float[2][8][8][3]; 97 interpreter.run(fourD, parsedOutputs); 98 float[] outputOneD = parsedOutputs[0][0][0]; 99 float[] expected = {3.69f, 19.62f, 23.43f}; 100 assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); 101 interpreter.close(); 102 } 103 104 @Test 105 public void testRunForMultipleInputsOutputs() { 106 Interpreter interpreter = new Interpreter(MODEL_FILE); 107 float[] oneD = {1.23f, 6.54f, 7.81f}; 108 float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; 109 float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; 110 float[][][][] fourD = {threeD, threeD}; 111 Object[] inputs = {fourD}; 112 float[][][][] parsedOutputs = new float[2][8][8][3]; 113 Map<Integer, Object> outputs = new HashMap<>(); 114 outputs.put(0, parsedOutputs); 115 interpreter.runForMultipleInputsOutputs(inputs, outputs); 116 float[] outputOneD = parsedOutputs[0][0][0]; 117 float[] expected = {3.69f, 19.62f, 23.43f}; 118 assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); 119 interpreter.close(); 120 } 121 122 @Test 123 public void testMobilenetRun() { 124 // Create a gray image. 125 float[][][][] img = new float[1][224][224][3]; 126 for (int i = 0; i < 224; ++i) { 127 for (int j = 0; j < 224; ++j) { 128 img[0][i][j][0] = 0.5f; 129 img[0][i][j][1] = 0.5f; 130 img[0][i][j][2] = 0.5f; 131 } 132 } 133 134 // Allocate memory to receive the output values. 135 float[][] labels = new float[1][1001]; 136 137 Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE); 138 interpreter.run(img, labels); 139 interpreter.close(); 140 141 assertThat(labels[0]) 142 .usingExactEquality() 143 .containsNoneOf(new float[] {Float.NaN, Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY}); 144 } 145 146 @Test 147 public void testRunWithWrongInputType() { 148 Interpreter interpreter = new Interpreter(MODEL_FILE); 149 int[] oneD = {4, 3, 9}; 150 int[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; 151 int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; 152 int[][][][] fourD = {threeD, threeD}; 153 float[][][][] parsedOutputs = new float[2][8][8][3]; 154 try { 155 interpreter.run(fourD, parsedOutputs); 156 fail(); 157 } catch (IllegalArgumentException e) { 158 assertThat(e) 159 .hasMessageThat() 160 .contains( 161 "DataType (2) of input data does not match with the DataType (1) of model inputs."); 162 } 163 interpreter.close(); 164 } 165 166 @Test 167 public void testRunWithWrongOutputType() { 168 Interpreter interpreter = new Interpreter(MODEL_FILE); 169 float[] oneD = {1.23f, 6.54f, 7.81f}; 170 float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; 171 float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; 172 float[][][][] fourD = {threeD, threeD}; 173 int[][][][] parsedOutputs = new int[2][8][8][3]; 174 try { 175 interpreter.run(fourD, parsedOutputs); 176 fail(); 177 } catch (IllegalArgumentException e) { 178 assertThat(e) 179 .hasMessageThat() 180 .contains( 181 "Cannot convert an TensorFlowLite tensor with type " 182 + "FLOAT32 to a Java object of type [[[[I (which is compatible with the" 183 + " TensorFlowLite type INT32)"); 184 } 185 interpreter.close(); 186 } 187 188 @Test 189 public void testGetInputIndex() { 190 Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE); 191 try { 192 interpreter.getInputIndex("WrongInputName"); 193 fail(); 194 } catch (IllegalArgumentException e) { 195 assertThat(e) 196 .hasMessageThat() 197 .contains( 198 "WrongInputName is not a valid name for any input. The indexes of the inputs" 199 + " are {input=0}"); 200 } 201 int index = interpreter.getInputIndex("input"); 202 assertThat(index).isEqualTo(0); 203 } 204 205 @Test 206 public void testGetOutputIndex() { 207 Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE); 208 try { 209 interpreter.getOutputIndex("WrongOutputName"); 210 fail(); 211 } catch (IllegalArgumentException e) { 212 assertThat(e) 213 .hasMessageThat() 214 .contains( 215 "WrongOutputName is not a valid name for any output. The indexes of the outputs" 216 + " are {MobilenetV1/Predictions/Softmax=0}"); 217 } 218 int index = interpreter.getOutputIndex("MobilenetV1/Predictions/Softmax"); 219 assertThat(index).isEqualTo(0); 220 } 221 } 222