/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.lite;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * An internal wrapper that wraps the native interpreter and controls model execution.
 *
 * <p><b>WARNING:</b> Resources consumed by the {@code NativeInterpreterWrapper} object must be
 * explicitly freed by invoking the {@link #close()} method when the {@code
 * NativeInterpreterWrapper} object is no longer needed.
 */
final class NativeInterpreterWrapper implements AutoCloseable {

  NativeInterpreterWrapper(String modelPath) {
    this(modelPath, /* options= */ null);
  }

  NativeInterpreterWrapper(String modelPath, Interpreter.Options options) {
    long errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
    long modelHandle = createModel(modelPath, errorHandle);
    init(errorHandle, modelHandle, options);
  }

  NativeInterpreterWrapper(ByteBuffer byteBuffer) {
    this(byteBuffer, /* options= */ null);
  }

  NativeInterpreterWrapper(ByteBuffer buffer, Interpreter.Options options) {
    if (buffer == null
        || (!(buffer instanceof MappedByteBuffer)
            && (!buffer.isDirect() || buffer.order() != ByteOrder.nativeOrder()))) {
      throw new IllegalArgumentException(
          "Model ByteBuffer should be either a MappedByteBuffer of the model file, or a direct "
              + "ByteBuffer using ByteOrder.nativeOrder() which contains bytes of model content.");
    }
    this.modelByteBuffer = buffer;
    long errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
    long modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
    init(errorHandle, modelHandle, options);
  }
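
  // The check above requires the model buffer to be either memory-mapped or direct with native
  // byte order. A minimal sketch of memory-mapping a model file, which satisfies the
  // MappedByteBuffer branch ("model.tflite" is a hypothetical path):
  //
  //   try (FileChannel channel =
  //       FileChannel.open(Paths.get("model.tflite"), StandardOpenOption.READ)) {
  //     MappedByteBuffer modelBuffer = channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size());
  //     NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(modelBuffer);
  //   }
  //
  // Alternatively, a direct buffer in native byte order also passes (modelBytes is hypothetical):
  //
  //   ByteBuffer direct =
  //       ByteBuffer.allocateDirect(modelBytes.length).order(ByteOrder.nativeOrder());
  //   direct.put(modelBytes);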

  private void init(long errorHandle, long modelHandle, Interpreter.Options options) {
    if (options == null) {
      options = new Interpreter.Options();
    }
    this.errorHandle = errorHandle;
    this.modelHandle = modelHandle;
    this.interpreterHandle = createInterpreter(modelHandle, errorHandle, options.numThreads);
    this.inputTensors = new Tensor[getInputCount(interpreterHandle)];
    this.outputTensors = new Tensor[getOutputCount(interpreterHandle)];
    if (options.useNNAPI != null) {
      setUseNNAPI(options.useNNAPI.booleanValue());
    }
    if (options.allowFp16PrecisionForFp32 != null) {
      allowFp16PrecisionForFp32(
          interpreterHandle, options.allowFp16PrecisionForFp32.booleanValue());
    }
    if (options.allowBufferHandleOutput != null) {
      allowBufferHandleOutput(interpreterHandle, options.allowBufferHandleOutput.booleanValue());
    }
    for (Delegate delegate : options.delegates) {
      applyDelegate(interpreterHandle, errorHandle, delegate.getNativeHandle());
      delegates.add(delegate);
    }
    allocateTensors(interpreterHandle, errorHandle);
    this.isMemoryAllocated = true;
  }
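
  // A minimal sketch of how these options reach init(); callers normally go through the public
  // Interpreter API rather than constructing this wrapper directly ("model.tflite" and the
  // settings are illustrative, not recommendations):
  //
  //   Interpreter.Options options = new Interpreter.Options();
  //   options.setNumThreads(2);   // populates options.numThreads, passed to createInterpreter()
  //   options.setUseNNAPI(true);  // populates options.useNNAPI, applied via setUseNNAPI()
  //   NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper("model.tflite", options);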

  /** Releases resources associated with this {@code NativeInterpreterWrapper}. */
  @Override
  public void close() {
    // Close the tensors first as they may reference the native interpreter.
    for (int i = 0; i < inputTensors.length; ++i) {
      if (inputTensors[i] != null) {
        inputTensors[i].close();
        inputTensors[i] = null;
      }
    }
    for (int i = 0; i < outputTensors.length; ++i) {
      if (outputTensors[i] != null) {
        outputTensors[i].close();
        outputTensors[i] = null;
      }
    }
    delete(errorHandle, modelHandle, interpreterHandle);
    errorHandle = 0;
    modelHandle = 0;
    interpreterHandle = 0;
    modelByteBuffer = null;
    inputsIndexes = null;
    outputsIndexes = null;
    isMemoryAllocated = false;
    delegates.clear();
  }
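
  // Because this class implements AutoCloseable, try-with-resources gives the deterministic
  // cleanup the class doc warns about; a minimal sketch ("model.tflite" is hypothetical):
  //
  //   try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper("model.tflite")) {
  //     // ... resize, run, read outputs ...
  //   } // close() runs here even if inference throws.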

  /** Sets inputs, runs model inference and returns outputs. */
  void run(Object[] inputs, Map<Integer, Object> outputs) {
    inferenceDurationNanoseconds = -1;
    if (inputs == null || inputs.length == 0) {
      throw new IllegalArgumentException("Input error: Inputs should not be null or empty.");
    }
    if (outputs == null || outputs.isEmpty()) {
      throw new IllegalArgumentException("Input error: Outputs should not be null or empty.");
    }

    // TODO(b/80431971): Remove implicit resize after deprecating multi-dimensional array inputs.
    // Rather than forcing an immediate resize + allocation if an input's shape differs, we first
    // flush all resizes, avoiding redundant allocations.
    for (int i = 0; i < inputs.length; ++i) {
      Tensor tensor = getInputTensor(i);
      int[] newShape = tensor.getInputShapeIfDifferent(inputs[i]);
      if (newShape != null) {
        resizeInput(i, newShape);
      }
    }

    boolean needsAllocation = !isMemoryAllocated;
    if (needsAllocation) {
      allocateTensors(interpreterHandle, errorHandle);
      isMemoryAllocated = true;
    }

    for (int i = 0; i < inputs.length; ++i) {
      getInputTensor(i).setTo(inputs[i]);
    }

    long inferenceStartNanos = System.nanoTime();
    run(interpreterHandle, errorHandle);
    long inferenceDurationNanoseconds = System.nanoTime() - inferenceStartNanos;

    // Allocation can trigger dynamic resizing of output tensors, so refresh all output shapes.
    if (needsAllocation) {
      for (int i = 0; i < outputTensors.length; ++i) {
        if (outputTensors[i] != null) {
          outputTensors[i].refreshShape();
        }
      }
    }
    for (Map.Entry<Integer, Object> output : outputs.entrySet()) {
      getOutputTensor(output.getKey()).copyTo(output.getValue());
    }

    // Only set if the entire operation succeeds.
    this.inferenceDurationNanoseconds = inferenceDurationNanoseconds;
  }
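
  // A minimal sketch of driving run(): inputs are positional, outputs are keyed by output index,
  // and results are copied into caller-owned arrays (the shapes below are hypothetical):
  //
  //   float[][] input = new float[1][224];  // mismatched shapes trigger an implicit resize
  //   float[][] output = new float[1][10];
  //   Map<Integer, Object> outputs = new HashMap<>();
  //   outputs.put(0, output);
  //   wrapper.run(new Object[] {input}, outputs);  // "output" now holds the results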

  private static native boolean run(long interpreterHandle, long errorHandle);

  /** Resizes dimensions of a specific input. */
  void resizeInput(int idx, int[] dims) {
    if (resizeInput(interpreterHandle, errorHandle, idx, dims)) {
      isMemoryAllocated = false;
      if (inputTensors[idx] != null) {
        inputTensors[idx].refreshShape();
      }
    }
  }
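
  // A minimal sketch of an explicit resize; allocation is deferred, so the next run() pays the
  // re-allocation cost once (the shape is hypothetical):
  //
  //   wrapper.resizeInput(0, new int[] {4, 224, 224, 3});  // e.g. grow the batch dimension to 4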

  private static native boolean resizeInput(
      long interpreterHandle, long errorHandle, int inputIdx, int[] dims);

  void setUseNNAPI(boolean useNNAPI) {
    useNNAPI(interpreterHandle, useNNAPI);
  }

  void setNumThreads(int numThreads) {
    numThreads(interpreterHandle, numThreads);
  }

  void modifyGraphWithDelegate(Delegate delegate) {
    applyDelegate(interpreterHandle, errorHandle, delegate.getNativeHandle());
    delegates.add(delegate);
  }

  /** Gets index of an input given its name. */
  int getInputIndex(String name) {
    if (inputsIndexes == null) {
      String[] names = getInputNames(interpreterHandle);
      inputsIndexes = new HashMap<>();
      if (names != null) {
        for (int i = 0; i < names.length; ++i) {
          inputsIndexes.put(names[i], i);
        }
      }
    }
    if (inputsIndexes.containsKey(name)) {
      return inputsIndexes.get(name);
    } else {
      throw new IllegalArgumentException(
          String.format(
              "Input error: '%s' is not a valid name for any input. Names of inputs and their "
                  + "indexes are %s",
              name, inputsIndexes.toString()));
    }
  }

  /** Gets index of an output given its name. */
  int getOutputIndex(String name) {
    if (outputsIndexes == null) {
      String[] names = getOutputNames(interpreterHandle);
      outputsIndexes = new HashMap<>();
      if (names != null) {
        for (int i = 0; i < names.length; ++i) {
          outputsIndexes.put(names[i], i);
        }
      }
    }
    if (outputsIndexes.containsKey(name)) {
      return outputsIndexes.get(name);
    } else {
      throw new IllegalArgumentException(
          String.format(
              "Input error: '%s' is not a valid name for any output. Names of outputs and their "
                  + "indexes are %s",
              name, outputsIndexes.toString()));
    }
  }

  /**
   * Gets the last inference duration in nanoseconds. Returns null if no inference has been run,
   * or if the most recent inference failed.
   */
  Long getLastNativeInferenceDurationNanoseconds() {
    return (inferenceDurationNanoseconds < 0) ? null : inferenceDurationNanoseconds;
  }

  /**
   * Gets the quantization zero point of an output.
   *
   * @throws IllegalArgumentException if the output index is invalid.
   */
  int getOutputQuantizationZeroPoint(int index) {
    return getOutputQuantizationZeroPoint(interpreterHandle, index);
  }

  /**
   * Gets the quantization scale of an output.
   *
   * @throws IllegalArgumentException if the output index is invalid.
   */
  float getOutputQuantizationScale(int index) {
    return getOutputQuantizationScale(interpreterHandle, index);
  }
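
  // Together these two parameters describe TensorFlow Lite's affine quantization scheme,
  // realValue = scale * (quantizedValue - zeroPoint). A minimal sketch of dequantizing one
  // uint8 output value ("outputIndex" and "raw" are hypothetical):
  //
  //   float scale = wrapper.getOutputQuantizationScale(outputIndex);
  //   int zeroPoint = wrapper.getOutputQuantizationZeroPoint(outputIndex);
  //   float real = scale * ((raw & 0xFF) - zeroPoint);  // raw is the quantized byte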

  /** Gets the number of input tensors. */
  int getInputTensorCount() {
    return inputTensors.length;
  }

  /**
   * Gets the input {@link Tensor} for the provided input index.
   *
   * @throws IllegalArgumentException if the input index is invalid.
   */
  Tensor getInputTensor(int index) {
    if (index < 0 || index >= inputTensors.length) {
      throw new IllegalArgumentException("Invalid input Tensor index: " + index);
    }
    Tensor inputTensor = inputTensors[index];
    if (inputTensor == null) {
      inputTensor =
          inputTensors[index] =
              Tensor.fromIndex(interpreterHandle, getInputTensorIndex(interpreterHandle, index));
    }
    return inputTensor;
  }

  /** Gets the number of output tensors. */
  int getOutputTensorCount() {
    return outputTensors.length;
  }

  /**
   * Gets the output {@link Tensor} for the provided output index.
   *
   * @throws IllegalArgumentException if the output index is invalid.
   */
  Tensor getOutputTensor(int index) {
    if (index < 0 || index >= outputTensors.length) {
      throw new IllegalArgumentException("Invalid output Tensor index: " + index);
    }
    Tensor outputTensor = outputTensors[index];
    if (outputTensor == null) {
      outputTensor =
          outputTensors[index] =
              Tensor.fromIndex(interpreterHandle, getOutputTensorIndex(interpreterHandle, index));
    }
    return outputTensor;
  }

  private static native int getOutputDataType(long interpreterHandle, int outputIdx);

  private static native int getOutputQuantizationZeroPoint(long interpreterHandle, int outputIdx);

  private static native float getOutputQuantizationScale(long interpreterHandle, int outputIdx);

  private static final int ERROR_BUFFER_SIZE = 512;

  private long errorHandle;

  private long interpreterHandle;

  private long modelHandle;

  private long inferenceDurationNanoseconds = -1;

  private ByteBuffer modelByteBuffer;

  // Lazily constructed maps of input and output names to input and output Tensor indexes.
  private Map<String, Integer> inputsIndexes;
  private Map<String, Integer> outputsIndexes;

  // Lazily constructed and populated arrays of input and output Tensor wrappers.
  private Tensor[] inputTensors;
  private Tensor[] outputTensors;

  private boolean isMemoryAllocated = false;

  // As the Java Delegate owns the native delegate instance, we keep a strong ref to any injected
  // delegates for safety.
  private final List<Delegate> delegates = new ArrayList<>();

  private static native long allocateTensors(long interpreterHandle, long errorHandle);

  private static native int getInputTensorIndex(long interpreterHandle, int inputIdx);

  private static native int getOutputTensorIndex(long interpreterHandle, int outputIdx);

  private static native int getInputCount(long interpreterHandle);

  private static native int getOutputCount(long interpreterHandle);

  private static native String[] getInputNames(long interpreterHandle);

  private static native String[] getOutputNames(long interpreterHandle);

  private static native void useNNAPI(long interpreterHandle, boolean state);

  private static native void numThreads(long interpreterHandle, int numThreads);

  private static native void allowFp16PrecisionForFp32(long interpreterHandle, boolean allow);

  private static native void allowBufferHandleOutput(long interpreterHandle, boolean allow);

  private static native long createErrorReporter(int size);

  private static native long createModel(String modelPathOrBuffer, long errorHandle);

  private static native long createModelWithBuffer(ByteBuffer modelBuffer, long errorHandle);

  private static native long createInterpreter(long modelHandle, long errorHandle, int numThreads);

  private static native void applyDelegate(
      long interpreterHandle, long errorHandle, long delegateHandle);

  private static native void delete(long errorHandle, long modelHandle, long interpreterHandle);

  static {
    TensorFlowLite.init();
  }
}