Home | History | Annotate | Download | only in lite
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 package org.tensorflow.lite;
     17 
     18 import java.lang.reflect.Array;
     19 import java.nio.ByteBuffer;
     20 import java.nio.ByteOrder;
     21 import java.nio.MappedByteBuffer;
     22 import java.util.HashMap;
     23 import java.util.Map;
     24 
     25 /**
     26  * A wrapper that wraps the native interpreter and controls model execution.
     27  *
     28  * <p><b>WARNING:</b> Resources consumed by the {@code NativeInterpreterWrapper} object must be
     29  * explicitly freed by invoking the {@link #close()} method when the {@code
     30  * NativeInterpreterWrapper} object is no longer needed.
     31  */
     32 final class NativeInterpreterWrapper implements AutoCloseable {
     33 
     34   NativeInterpreterWrapper(String modelPath) {
     35     errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
     36     modelHandle = createModel(modelPath, errorHandle);
     37     interpreterHandle = createInterpreter(modelHandle, errorHandle);
     38   }
     39 
     40   /**
     41    * Initializes a {@code NativeInterpreterWrapper} with a {@code MappedByteBuffer}. The
     42    * MappedByteBuffer should not be modified after the construction of a {@code
     43    * NativeInterpreterWrapper}.
     44    */
     45   NativeInterpreterWrapper(MappedByteBuffer mappedByteBuffer) {
     46     modelByteBuffer = mappedByteBuffer;
     47     errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
     48     modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
     49     interpreterHandle = createInterpreter(modelHandle, errorHandle);
     50   }
     51 
     52   /** Releases resources associated with this {@code NativeInterpreterWrapper}. */
     53   @Override
     54   public void close() {
     55     delete(errorHandle, modelHandle, interpreterHandle);
     56     errorHandle = 0;
     57     modelHandle = 0;
     58     interpreterHandle = 0;
     59     modelByteBuffer = null;
     60     inputsIndexes = null;
     61     outputsIndexes = null;
     62   }
     63 
     64   /** Sets inputs, runs model inference and returns outputs. */
     65   Tensor[] run(Object[] inputs) {
     66     if (inputs == null || inputs.length == 0) {
     67       throw new IllegalArgumentException("Invalid inputs. Inputs should not be null or empty.");
     68     }
     69     int[] dataTypes = new int[inputs.length];
     70     Object[] sizes = new Object[inputs.length];
     71     int[] numsOfBytes = new int[inputs.length];
     72     for (int i = 0; i < inputs.length; ++i) {
     73       DataType dataType = dataTypeOf(inputs[i]);
     74       dataTypes[i] = dataType.getNumber();
     75       if (dataType == DataType.BYTEBUFFER) {
     76         ByteBuffer buffer = (ByteBuffer) inputs[i];
     77         if (buffer.order() != ByteOrder.nativeOrder()) {
     78           throw new IllegalArgumentException(
     79               "Invalid ByteBuffer. It shoud use ByteOrder.nativeOrder().");
     80         }
     81         numsOfBytes[i] = buffer.limit();
     82         sizes[i] = getInputDims(interpreterHandle, i, numsOfBytes[i]);
     83       } else if (isNonEmptyArray(inputs[i])) {
     84         int[] dims = shapeOf(inputs[i]);
     85         sizes[i] = dims;
     86         numsOfBytes[i] = dataType.elemByteSize() * numElements(dims);
     87       } else {
     88         throw new IllegalArgumentException(
     89             String.format(
     90                 "%d-th element of the %d inputs is not an array or a ByteBuffer.",
     91                 i, inputs.length));
     92       }
     93     }
     94     long[] outputsHandles =
     95         run(interpreterHandle, errorHandle, sizes, dataTypes, numsOfBytes, inputs);
     96     if (outputsHandles == null || outputsHandles.length == 0) {
     97       throw new IllegalStateException("Interpreter has no outputs.");
     98     }
     99     Tensor[] outputs = new Tensor[outputsHandles.length];
    100     for (int i = 0; i < outputsHandles.length; ++i) {
    101       outputs[i] = Tensor.fromHandle(outputsHandles[i]);
    102     }
    103     return outputs;
    104   }
    105 
    106   private static native long[] run(
    107       long interpreterHandle,
    108       long errorHandle,
    109       Object[] sizes,
    110       int[] dtypes,
    111       int[] numsOfBytes,
    112       Object[] values);
    113 
    114   /** Resizes dimensions of a specific input. */
    115   void resizeInput(int idx, int[] dims) {
    116     resizeInput(interpreterHandle, errorHandle, idx, dims);
    117   }
    118 
    119   private static native void resizeInput(
    120       long interpreterHandle, long errorHandle, int inputIdx, int[] dims);
    121 
    122   void setUseNNAPI(boolean useNNAPI) {
    123     useNNAPI(interpreterHandle, useNNAPI);
    124   }
    125 
    126   /** Gets index of an input given its name. */
    127   int getInputIndex(String name) {
    128     if (inputsIndexes == null) {
    129       String[] names = getInputNames(interpreterHandle);
    130       inputsIndexes = new HashMap<>();
    131       if (names != null) {
    132         for (int i = 0; i < names.length; ++i) {
    133           inputsIndexes.put(names[i], i);
    134         }
    135       }
    136     }
    137     if (inputsIndexes.containsKey(name)) {
    138       return inputsIndexes.get(name);
    139     } else {
    140       throw new IllegalArgumentException(
    141           String.format(
    142               "%s is not a valid name for any input. The indexes of the inputs are %s",
    143               name, inputsIndexes.toString()));
    144     }
    145   }
    146 
    147   /** Gets index of an output given its name. */
    148   int getOutputIndex(String name) {
    149     if (outputsIndexes == null) {
    150       String[] names = getOutputNames(interpreterHandle);
    151       outputsIndexes = new HashMap<>();
    152       if (names != null) {
    153         for (int i = 0; i < names.length; ++i) {
    154           outputsIndexes.put(names[i], i);
    155         }
    156       }
    157     }
    158     if (outputsIndexes.containsKey(name)) {
    159       return outputsIndexes.get(name);
    160     } else {
    161       throw new IllegalArgumentException(
    162           String.format(
    163               "%s is not a valid name for any output. The indexes of the outputs are %s",
    164               name, outputsIndexes.toString()));
    165     }
    166   }
    167 
    168   static int numElements(int[] shape) {
    169     if (shape == null) {
    170       return 0;
    171     }
    172     int n = 1;
    173     for (int i = 0; i < shape.length; i++) {
    174       n *= shape[i];
    175     }
    176     return n;
    177   }
    178 
    179   static boolean isNonEmptyArray(Object o) {
    180     return (o != null && o.getClass().isArray() && Array.getLength(o) != 0);
    181   }
    182 
    183   /** Returns the type of the data. */
    184   static DataType dataTypeOf(Object o) {
    185     if (o != null) {
    186       Class<?> c = o.getClass();
    187       while (c.isArray()) {
    188         c = c.getComponentType();
    189       }
    190       if (float.class.equals(c)) {
    191         return DataType.FLOAT32;
    192       } else if (int.class.equals(c)) {
    193         return DataType.INT32;
    194       } else if (byte.class.equals(c)) {
    195         return DataType.UINT8;
    196       } else if (long.class.equals(c)) {
    197         return DataType.INT64;
    198       } else if (ByteBuffer.class.isInstance(o)) {
    199         return DataType.BYTEBUFFER;
    200       }
    201     }
    202     throw new IllegalArgumentException("cannot resolve DataType of " + o.getClass().getName());
    203   }
    204 
    205   /** Returns the shape of an object as an int array. */
    206   static int[] shapeOf(Object o) {
    207     int size = numDimensions(o);
    208     int[] dimensions = new int[size];
    209     fillShape(o, 0, dimensions);
    210     return dimensions;
    211   }
    212 
    213   static int numDimensions(Object o) {
    214     if (o == null || !o.getClass().isArray()) {
    215       return 0;
    216     }
    217     if (Array.getLength(o) == 0) {
    218       throw new IllegalArgumentException("array lengths cannot be 0.");
    219     }
    220     return 1 + numDimensions(Array.get(o, 0));
    221   }
    222 
    223   static void fillShape(Object o, int dim, int[] shape) {
    224     if (shape == null || dim == shape.length) {
    225       return;
    226     }
    227     final int len = Array.getLength(o);
    228     if (shape[dim] == 0) {
    229       shape[dim] = len;
    230     } else if (shape[dim] != len) {
    231       throw new IllegalArgumentException(
    232           String.format("mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim));
    233     }
    234     for (int i = 0; i < len; ++i) {
    235       fillShape(Array.get(o, i), dim + 1, shape);
    236     }
    237   }
    238 
    239   private static final int ERROR_BUFFER_SIZE = 512;
    240 
    241   private long errorHandle;
    242 
    243   private long interpreterHandle;
    244 
    245   private long modelHandle;
    246 
    247   private int inputSize;
    248 
    249   private MappedByteBuffer modelByteBuffer;
    250 
    251   private Map<String, Integer> inputsIndexes;
    252 
    253   private Map<String, Integer> outputsIndexes;
    254 
    255   private static native String[] getInputNames(long interpreterHandle);
    256 
    257   private static native String[] getOutputNames(long interpreterHandle);
    258 
    259   private static native void useNNAPI(long interpreterHandle, boolean state);
    260 
    261   private static native long createErrorReporter(int size);
    262 
    263   private static native long createModel(String modelPathOrBuffer, long errorHandle);
    264 
    265   private static native long createModelWithBuffer(MappedByteBuffer modelBuffer, long errorHandle);
    266 
    267   private static native long createInterpreter(long modelHandle, long errorHandle);
    268 
    269   private static native void delete(long errorHandle, long modelHandle, long interpreterHandle);
    270 
    271   private static native int[] getInputDims(long interpreterHandle, int inputIdx, int numBytes);
    272 
    273   static {
    274     TensorFlowLite.init();
    275   }
    276 }
    277