Home | History | Annotate | Download | only in lite
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 package org.tensorflow.lite;
     17 
     18 import static com.google.common.truth.Truth.assertThat;
     19 import static org.junit.Assert.fail;
     20 
     21 import java.io.File;
     22 import java.nio.MappedByteBuffer;
     23 import java.nio.channels.FileChannel;
     24 import java.nio.file.Files;
     25 import java.nio.file.Path;
     26 import java.nio.file.StandardOpenOption;
     27 import java.util.EnumSet;
     28 import java.util.HashMap;
     29 import java.util.Map;
     30 import org.junit.Test;
     31 import org.junit.runner.RunWith;
     32 import org.junit.runners.JUnit4;
     33 
     34 /** Unit tests for {@link org.tensorflow.lite.Interpreter}. */
     35 @RunWith(JUnit4.class)
     36 public final class InterpreterTest {
     37 
     38   private static final File MODEL_FILE =
     39       new File("tensorflow/contrib/lite/java/src/testdata/add.bin");
     40 
     41   private static final File MOBILENET_MODEL_FILE =
     42       new File("tensorflow/contrib/lite/java/src/testdata/mobilenet.tflite.bin");
     43 
     44   @Test
     45   public void testInterpreter() throws Exception {
     46     Interpreter interpreter = new Interpreter(MODEL_FILE);
     47     assertThat(interpreter).isNotNull();
     48     interpreter.close();
     49   }
     50 
     51   @Test
     52   public void testRunWithMappedByteBufferModel() throws Exception {
     53     Path path = MODEL_FILE.toPath();
     54     FileChannel fileChannel =
     55         (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ));
     56     MappedByteBuffer mappedByteBuffer =
     57         fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size());
     58     Interpreter interpreter = new Interpreter(mappedByteBuffer);
     59     float[] oneD = {1.23f, 6.54f, 7.81f};
     60     float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
     61     float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
     62     float[][][][] fourD = {threeD, threeD};
     63     float[][][][] parsedOutputs = new float[2][8][8][3];
     64     interpreter.run(fourD, parsedOutputs);
     65     float[] outputOneD = parsedOutputs[0][0][0];
     66     float[] expected = {3.69f, 19.62f, 23.43f};
     67     assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
     68     interpreter.close();
     69     fileChannel.close();
     70   }
     71 
     72   @Test
     73   public void testRun() {
     74     Interpreter interpreter = new Interpreter(MODEL_FILE);
     75     Float[] oneD = {1.23f, 6.54f, 7.81f};
     76     Float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
     77     Float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
     78     Float[][][][] fourD = {threeD, threeD};
     79     Float[][][][] parsedOutputs = new Float[2][8][8][3];
     80     try {
     81       interpreter.run(fourD, parsedOutputs);
     82       fail();
     83     } catch (IllegalArgumentException e) {
     84       assertThat(e).hasMessageThat().contains("cannot resolve DataType of [[[[Ljava.lang.Float;");
     85     }
     86     interpreter.close();
     87   }
     88 
     89   @Test
     90   public void testRunWithBoxedInputs() {
     91     Interpreter interpreter = new Interpreter(MODEL_FILE);
     92     float[] oneD = {1.23f, 6.54f, 7.81f};
     93     float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
     94     float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
     95     float[][][][] fourD = {threeD, threeD};
     96     float[][][][] parsedOutputs = new float[2][8][8][3];
     97     interpreter.run(fourD, parsedOutputs);
     98     float[] outputOneD = parsedOutputs[0][0][0];
     99     float[] expected = {3.69f, 19.62f, 23.43f};
    100     assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
    101     interpreter.close();
    102   }
    103 
    104   @Test
    105   public void testRunForMultipleInputsOutputs() {
    106     Interpreter interpreter = new Interpreter(MODEL_FILE);
    107     float[] oneD = {1.23f, 6.54f, 7.81f};
    108     float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
    109     float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
    110     float[][][][] fourD = {threeD, threeD};
    111     Object[] inputs = {fourD};
    112     float[][][][] parsedOutputs = new float[2][8][8][3];
    113     Map<Integer, Object> outputs = new HashMap<>();
    114     outputs.put(0, parsedOutputs);
    115     interpreter.runForMultipleInputsOutputs(inputs, outputs);
    116     float[] outputOneD = parsedOutputs[0][0][0];
    117     float[] expected = {3.69f, 19.62f, 23.43f};
    118     assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
    119     interpreter.close();
    120   }
    121 
    122   @Test
    123   public void testMobilenetRun() {
    124     // Create a gray image.
    125     float[][][][] img = new float[1][224][224][3];
    126     for (int i = 0; i < 224; ++i) {
    127       for (int j = 0; j < 224; ++j) {
    128         img[0][i][j][0] = 0.5f;
    129         img[0][i][j][1] = 0.5f;
    130         img[0][i][j][2] = 0.5f;
    131       }
    132     }
    133 
    134     // Allocate memory to receive the output values.
    135     float[][] labels = new float[1][1001];
    136 
    137     Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE);
    138     interpreter.run(img, labels);
    139     interpreter.close();
    140 
    141     assertThat(labels[0])
    142         .usingExactEquality()
    143         .containsNoneOf(new float[] {Float.NaN, Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY});
    144   }
    145 
    146   @Test
    147   public void testRunWithWrongInputType() {
    148     Interpreter interpreter = new Interpreter(MODEL_FILE);
    149     int[] oneD = {4, 3, 9};
    150     int[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
    151     int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
    152     int[][][][] fourD = {threeD, threeD};
    153     float[][][][] parsedOutputs = new float[2][8][8][3];
    154     try {
    155       interpreter.run(fourD, parsedOutputs);
    156       fail();
    157     } catch (IllegalArgumentException e) {
    158       assertThat(e)
    159           .hasMessageThat()
    160           .contains(
    161               "DataType (2) of input data does not match with the DataType (1) of model inputs.");
    162     }
    163     interpreter.close();
    164   }
    165 
    166   @Test
    167   public void testRunWithWrongOutputType() {
    168     Interpreter interpreter = new Interpreter(MODEL_FILE);
    169     float[] oneD = {1.23f, 6.54f, 7.81f};
    170     float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
    171     float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
    172     float[][][][] fourD = {threeD, threeD};
    173     int[][][][] parsedOutputs = new int[2][8][8][3];
    174     try {
    175       interpreter.run(fourD, parsedOutputs);
    176       fail();
    177     } catch (IllegalArgumentException e) {
    178       assertThat(e)
    179           .hasMessageThat()
    180           .contains(
    181               "Cannot convert an TensorFlowLite tensor with type "
    182                   + "FLOAT32 to a Java object of type [[[[I (which is compatible with the"
    183                   + " TensorFlowLite type INT32)");
    184     }
    185     interpreter.close();
    186   }
    187 
    188   @Test
    189   public void testGetInputIndex() {
    190     Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE);
    191     try {
    192       interpreter.getInputIndex("WrongInputName");
    193       fail();
    194     } catch (IllegalArgumentException e) {
    195       assertThat(e)
    196           .hasMessageThat()
    197           .contains(
    198               "WrongInputName is not a valid name for any input. The indexes of the inputs"
    199                   + " are {input=0}");
    200     }
    201     int index = interpreter.getInputIndex("input");
    202     assertThat(index).isEqualTo(0);
    203   }
    204 
    205   @Test
    206   public void testGetOutputIndex() {
    207     Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE);
    208     try {
    209       interpreter.getOutputIndex("WrongOutputName");
    210       fail();
    211     } catch (IllegalArgumentException e) {
    212       assertThat(e)
    213           .hasMessageThat()
    214           .contains(
    215               "WrongOutputName is not a valid name for any output. The indexes of the outputs"
    216                   + " are {MobilenetV1/Predictions/Softmax=0}");
    217     }
    218     int index = interpreter.getOutputIndex("MobilenetV1/Predictions/Softmax");
    219     assertThat(index).isEqualTo(0);
    220   }
    221 }
    222