/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.lite;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * An internal wrapper that wraps the native interpreter and controls model execution.
 *
 * <p><b>WARNING:</b> Resources consumed by the {@code NativeInterpreterWrapper} object must be
 * explicitly freed by invoking the {@link #close()} method when the {@code
 * NativeInterpreterWrapper} object is no longer needed.
 */
final class NativeInterpreterWrapper implements AutoCloseable {

  NativeInterpreterWrapper(String modelPath) {
    this(modelPath, /* options= */ null);
  }

  NativeInterpreterWrapper(String modelPath, Interpreter.Options options) {
    long errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
    long modelHandle = createModel(modelPath, errorHandle);
    init(errorHandle, modelHandle, options);
  }

  NativeInterpreterWrapper(ByteBuffer byteBuffer) {
    this(byteBuffer, /* options= */ null);
  }

  NativeInterpreterWrapper(ByteBuffer buffer, Interpreter.Options options) {
    if (buffer == null
        || (!(buffer instanceof MappedByteBuffer)
            && (!buffer.isDirect() || buffer.order() != ByteOrder.nativeOrder()))) {
      throw new IllegalArgumentException(
          "Model ByteBuffer should be either a MappedByteBuffer of the model file, or a direct "
              + "ByteBuffer using ByteOrder.nativeOrder() which contains bytes of model content.");
    }
    this.modelByteBuffer = buffer;
    long errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
    long modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
    init(errorHandle, modelHandle, options);
  }
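
  // A minimal usage sketch (illustrative only, not part of this class; the file name and the
  // FileChannel/Paths/StandardOpenOption imports are assumed): a caller maps the model file into
  // a MappedByteBuffer and, because this class implements AutoCloseable, can rely on
  // try-with-resources to release the native resources noted in the class WARNING above.
  //
  //   try (FileChannel channel =
  //       FileChannel.open(Paths.get("model.tflite"), StandardOpenOption.READ)) {
  //     MappedByteBuffer model = channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size());
  //     try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(model)) {
  //       // ... run inference ...
  //     }
  //   }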
  private void init(long errorHandle, long modelHandle, Interpreter.Options options) {
    if (options == null) {
      options = new Interpreter.Options();
    }
    this.errorHandle = errorHandle;
    this.modelHandle = modelHandle;
    this.interpreterHandle = createInterpreter(modelHandle, errorHandle, options.numThreads);
    this.inputTensors = new Tensor[getInputCount(interpreterHandle)];
    this.outputTensors = new Tensor[getOutputCount(interpreterHandle)];
    if (options.useNNAPI != null) {
      setUseNNAPI(options.useNNAPI.booleanValue());
    }
    if (options.allowFp16PrecisionForFp32 != null) {
      allowFp16PrecisionForFp32(
          interpreterHandle, options.allowFp16PrecisionForFp32.booleanValue());
    }
    if (options.allowBufferHandleOutput != null) {
      allowBufferHandleOutput(interpreterHandle, options.allowBufferHandleOutput.booleanValue());
    }
    for (Delegate delegate : options.delegates) {
      applyDelegate(interpreterHandle, errorHandle, delegate.getNativeHandle());
      delegates.add(delegate);
    }
    allocateTensors(interpreterHandle, errorHandle);
    this.isMemoryAllocated = true;
  }

  /** Releases resources associated with this {@code NativeInterpreterWrapper}. */
  @Override
  public void close() {
    // Close the tensors first as they may reference the native interpreter.
    for (int i = 0; i < inputTensors.length; ++i) {
      if (inputTensors[i] != null) {
        inputTensors[i].close();
        inputTensors[i] = null;
      }
    }
    for (int i = 0; i < outputTensors.length; ++i) {
      if (outputTensors[i] != null) {
        outputTensors[i].close();
        outputTensors[i] = null;
      }
    }
    delete(errorHandle, modelHandle, interpreterHandle);
    errorHandle = 0;
    modelHandle = 0;
    interpreterHandle = 0;
    modelByteBuffer = null;
    inputsIndexes = null;
    outputsIndexes = null;
    isMemoryAllocated = false;
    delegates.clear();
  }

  /** Sets inputs, runs model inference and returns outputs. */
  void run(Object[] inputs, Map<Integer, Object> outputs) {
    inferenceDurationNanoseconds = -1;
    if (inputs == null || inputs.length == 0) {
      throw new IllegalArgumentException("Input error: Inputs should not be null or empty.");
    }
    if (outputs == null || outputs.isEmpty()) {
      throw new IllegalArgumentException("Input error: Outputs should not be null or empty.");
    }

    // TODO(b/80431971): Remove implicit resize after deprecating multi-dimensional array inputs.
    // Rather than forcing an immediate resize + allocation if an input's shape differs, we first
    // flush all resizes, avoiding redundant allocations.
    for (int i = 0; i < inputs.length; ++i) {
      Tensor tensor = getInputTensor(i);
      int[] newShape = tensor.getInputShapeIfDifferent(inputs[i]);
      if (newShape != null) {
        resizeInput(i, newShape);
      }
    }

    boolean needsAllocation = !isMemoryAllocated;
    if (needsAllocation) {
      allocateTensors(interpreterHandle, errorHandle);
      isMemoryAllocated = true;
    }

    for (int i = 0; i < inputs.length; ++i) {
      getInputTensor(i).setTo(inputs[i]);
    }

    long inferenceStartNanos = System.nanoTime();
    run(interpreterHandle, errorHandle);
    long inferenceDurationNanoseconds = System.nanoTime() - inferenceStartNanos;

    // Allocation can trigger dynamic resizing of output tensors, so refresh all output shapes.
    if (needsAllocation) {
      for (int i = 0; i < outputTensors.length; ++i) {
        if (outputTensors[i] != null) {
          outputTensors[i].refreshShape();
        }
      }
    }
    for (Map.Entry<Integer, Object> output : outputs.entrySet()) {
      getOutputTensor(output.getKey()).copyTo(output.getValue());
    }

    // Only set if the entire operation succeeds.
    this.inferenceDurationNanoseconds = inferenceDurationNanoseconds;
  }

  private static native boolean run(long interpreterHandle, long errorHandle);
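
  // A minimal run() sketch (illustrative; the [1, 4] -> [1, 2] float shapes are assumed, not
  // taken from any particular model): inputs are passed positionally, and outputs are written
  // into caller-provided containers keyed by output index.
  //
  //   float[][] input = new float[1][4];
  //   float[][] output = new float[1][2];
  //   Map<Integer, Object> outputs = new HashMap<>();
  //   outputs.put(0, output);
  //   wrapper.run(new Object[] {input}, outputs);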
  /** Resizes dimensions of a specific input. */
  void resizeInput(int idx, int[] dims) {
    if (resizeInput(interpreterHandle, errorHandle, idx, dims)) {
      isMemoryAllocated = false;
      if (inputTensors[idx] != null) {
        inputTensors[idx].refreshShape();
      }
    }
  }

  private static native boolean resizeInput(
      long interpreterHandle, long errorHandle, int inputIdx, int[] dims);

  void setUseNNAPI(boolean useNNAPI) {
    useNNAPI(interpreterHandle, useNNAPI);
  }

  void setNumThreads(int numThreads) {
    numThreads(interpreterHandle, numThreads);
  }

  void modifyGraphWithDelegate(Delegate delegate) {
    applyDelegate(interpreterHandle, errorHandle, delegate.getNativeHandle());
    delegates.add(delegate);
  }

  /** Gets the index of an input given its name. */
  int getInputIndex(String name) {
    if (inputsIndexes == null) {
      String[] names = getInputNames(interpreterHandle);
      inputsIndexes = new HashMap<>();
      if (names != null) {
        for (int i = 0; i < names.length; ++i) {
          inputsIndexes.put(names[i], i);
        }
      }
    }
    if (inputsIndexes.containsKey(name)) {
      return inputsIndexes.get(name);
    } else {
      throw new IllegalArgumentException(
          String.format(
              "Input error: '%s' is not a valid name for any input. Names of inputs and their "
                  + "indexes are %s",
              name, inputsIndexes.toString()));
    }
  }

  /** Gets the index of an output given its name. */
  int getOutputIndex(String name) {
    if (outputsIndexes == null) {
      String[] names = getOutputNames(interpreterHandle);
      outputsIndexes = new HashMap<>();
      if (names != null) {
        for (int i = 0; i < names.length; ++i) {
          outputsIndexes.put(names[i], i);
        }
      }
    }
    if (outputsIndexes.containsKey(name)) {
      return outputsIndexes.get(name);
    } else {
      throw new IllegalArgumentException(
          String.format(
              "Input error: '%s' is not a valid name for any output. Names of outputs and their "
                  + "indexes are %s",
              name, outputsIndexes.toString()));
    }
  }

  /**
   * Gets the last inference duration in nanoseconds. Returns null if there is no previous
   * inference run, or if the last inference run failed.
   */
  Long getLastNativeInferenceDurationNanoseconds() {
    return (inferenceDurationNanoseconds < 0) ? null : inferenceDurationNanoseconds;
  }

  /**
   * Gets the quantization zero point of an output.
   *
   * @throws IllegalArgumentException if the output index is invalid.
   */
  int getOutputQuantizationZeroPoint(int index) {
    return getOutputQuantizationZeroPoint(interpreterHandle, index);
  }

  /**
   * Gets the quantization scale of an output.
   *
   * @throws IllegalArgumentException if the output index is invalid.
   */
  float getOutputQuantizationScale(int index) {
    return getOutputQuantizationScale(interpreterHandle, index);
  }
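
  // Illustrative note (standard affine dequantization; output index 0 and the 'quantized' value
  // are assumed for the example): the zero point and scale above let a caller recover real values
  // from a quantized output as realValue = scale * (quantizedValue - zeroPoint).
  //
  //   int zeroPoint = wrapper.getOutputQuantizationZeroPoint(0);
  //   float scale = wrapper.getOutputQuantizationScale(0);
  //   float realValue = scale * (quantized - zeroPoint);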
  /** Gets the number of input tensors. */
  int getInputTensorCount() {
    return inputTensors.length;
  }

  /**
   * Gets the input {@link Tensor} for the provided input index.
   *
   * @throws IllegalArgumentException if the input index is invalid.
   */
  Tensor getInputTensor(int index) {
    if (index < 0 || index >= inputTensors.length) {
      throw new IllegalArgumentException("Invalid input Tensor index: " + index);
    }
    Tensor inputTensor = inputTensors[index];
    if (inputTensor == null) {
      inputTensor =
          inputTensors[index] =
              Tensor.fromIndex(interpreterHandle, getInputTensorIndex(interpreterHandle, index));
    }
    return inputTensor;
  }

  /** Gets the number of output tensors. */
  int getOutputTensorCount() {
    return outputTensors.length;
  }

  /**
   * Gets the output {@link Tensor} for the provided output index.
   *
   * @throws IllegalArgumentException if the output index is invalid.
   */
  Tensor getOutputTensor(int index) {
    if (index < 0 || index >= outputTensors.length) {
      throw new IllegalArgumentException("Invalid output Tensor index: " + index);
    }
    Tensor outputTensor = outputTensors[index];
    if (outputTensor == null) {
      outputTensor =
          outputTensors[index] =
              Tensor.fromIndex(interpreterHandle, getOutputTensorIndex(interpreterHandle, index));
    }
    return outputTensor;
  }

  private static native int getOutputDataType(long interpreterHandle, int outputIdx);

  private static native int getOutputQuantizationZeroPoint(long interpreterHandle, int outputIdx);

  private static native float getOutputQuantizationScale(long interpreterHandle, int outputIdx);

  private static final int ERROR_BUFFER_SIZE = 512;

  private long errorHandle;

  private long interpreterHandle;

  private long modelHandle;

  private long inferenceDurationNanoseconds = -1;

  private ByteBuffer modelByteBuffer;

  // Lazily constructed maps of input and output names to input and output Tensor indexes.
  private Map<String, Integer> inputsIndexes;
  private Map<String, Integer> outputsIndexes;

  // Lazily constructed and populated arrays of input and output Tensor wrappers.
  private Tensor[] inputTensors;
  private Tensor[] outputTensors;

  private boolean isMemoryAllocated = false;
  // As the Java Delegate owns the native delegate instance, we keep a strong ref to any injected
  // delegates for safety.
  private final List<Delegate> delegates = new ArrayList<>();

  private static native long allocateTensors(long interpreterHandle, long errorHandle);

  private static native int getInputTensorIndex(long interpreterHandle, int inputIdx);

  private static native int getOutputTensorIndex(long interpreterHandle, int outputIdx);

  private static native int getInputCount(long interpreterHandle);

  private static native int getOutputCount(long interpreterHandle);

  private static native String[] getInputNames(long interpreterHandle);

  private static native String[] getOutputNames(long interpreterHandle);

  private static native void useNNAPI(long interpreterHandle, boolean state);

  private static native void numThreads(long interpreterHandle, int numThreads);

  private static native void allowFp16PrecisionForFp32(long interpreterHandle, boolean allow);

  private static native void allowBufferHandleOutput(long interpreterHandle, boolean allow);

  private static native long createErrorReporter(int size);

  private static native long createModel(String modelPathOrBuffer, long errorHandle);

  private static native long createModelWithBuffer(ByteBuffer modelBuffer, long errorHandle);

  private static native long createInterpreter(long modelHandle, long errorHandle, int numThreads);

  private static native void applyDelegate(
      long interpreterHandle, long errorHandle, long delegateHandle);

  private static native void delete(long errorHandle, long modelHandle, long interpreterHandle);

  static {
    TensorFlowLite.init();
  }
}