1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow.lite; 17 18 import java.lang.reflect.Array; 19 import java.nio.ByteBuffer; 20 import java.nio.ByteOrder; 21 import java.nio.MappedByteBuffer; 22 import java.util.HashMap; 23 import java.util.Map; 24 25 /** 26 * A wrapper wraps native interpreter and controls model execution. 27 * 28 * <p><b>WARNING:</b> Resources consumed by the {@code NativeInterpreterWrapper} object must be 29 * explicitly freed by invoking the {@link #close()} method when the {@code 30 * NativeInterpreterWrapper} object is no longer needed. 31 */ 32 final class NativeInterpreterWrapper implements AutoCloseable { 33 34 NativeInterpreterWrapper(String modelPath) { 35 errorHandle = createErrorReporter(ERROR_BUFFER_SIZE); 36 modelHandle = createModel(modelPath, errorHandle); 37 interpreterHandle = createInterpreter(modelHandle, errorHandle); 38 } 39 40 /** 41 * Initializes a {@code NativeInterpreterWrapper} with a {@code MappedByteBuffer}. The 42 * MappedByteBuffer should not be modified after the construction of a {@code 43 * NativeInterpreterWrapper}. 44 */ 45 NativeInterpreterWrapper(MappedByteBuffer mappedByteBuffer) { 46 modelByteBuffer = mappedByteBuffer; 47 errorHandle = createErrorReporter(ERROR_BUFFER_SIZE); 48 modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle); 49 interpreterHandle = createInterpreter(modelHandle, errorHandle); 50 } 51 52 /** Releases resources associated with this {@code NativeInterpreterWrapper}. */ 53 @Override 54 public void close() { 55 delete(errorHandle, modelHandle, interpreterHandle); 56 errorHandle = 0; 57 modelHandle = 0; 58 interpreterHandle = 0; 59 modelByteBuffer = null; 60 inputsIndexes = null; 61 outputsIndexes = null; 62 } 63 64 /** Sets inputs, runs model inference and returns outputs. */ 65 Tensor[] run(Object[] inputs) { 66 if (inputs == null || inputs.length == 0) { 67 throw new IllegalArgumentException("Invalid inputs. Inputs should not be null or empty."); 68 } 69 int[] dataTypes = new int[inputs.length]; 70 Object[] sizes = new Object[inputs.length]; 71 int[] numsOfBytes = new int[inputs.length]; 72 for (int i = 0; i < inputs.length; ++i) { 73 DataType dataType = dataTypeOf(inputs[i]); 74 dataTypes[i] = dataType.getNumber(); 75 if (dataType == DataType.BYTEBUFFER) { 76 ByteBuffer buffer = (ByteBuffer) inputs[i]; 77 if (buffer.order() != ByteOrder.nativeOrder()) { 78 throw new IllegalArgumentException( 79 "Invalid ByteBuffer. It shoud use ByteOrder.nativeOrder()."); 80 } 81 numsOfBytes[i] = buffer.limit(); 82 sizes[i] = getInputDims(interpreterHandle, i, numsOfBytes[i]); 83 } else if (isNonEmptyArray(inputs[i])) { 84 int[] dims = shapeOf(inputs[i]); 85 sizes[i] = dims; 86 numsOfBytes[i] = dataType.elemByteSize() * numElements(dims); 87 } else { 88 throw new IllegalArgumentException( 89 String.format( 90 "%d-th element of the %d inputs is not an array or a ByteBuffer.", 91 i, inputs.length)); 92 } 93 } 94 long[] outputsHandles = 95 run(interpreterHandle, errorHandle, sizes, dataTypes, numsOfBytes, inputs); 96 if (outputsHandles == null || outputsHandles.length == 0) { 97 throw new IllegalStateException("Interpreter has no outputs."); 98 } 99 Tensor[] outputs = new Tensor[outputsHandles.length]; 100 for (int i = 0; i < outputsHandles.length; ++i) { 101 outputs[i] = Tensor.fromHandle(outputsHandles[i]); 102 } 103 return outputs; 104 } 105 106 private static native long[] run( 107 long interpreterHandle, 108 long errorHandle, 109 Object[] sizes, 110 int[] dtypes, 111 int[] numsOfBytes, 112 Object[] values); 113 114 /** Resizes dimensions of a specific input. */ 115 void resizeInput(int idx, int[] dims) { 116 resizeInput(interpreterHandle, errorHandle, idx, dims); 117 } 118 119 private static native void resizeInput( 120 long interpreterHandle, long errorHandle, int inputIdx, int[] dims); 121 122 void setUseNNAPI(boolean useNNAPI) { 123 useNNAPI(interpreterHandle, useNNAPI); 124 } 125 126 /** Gets index of an input given its name. */ 127 int getInputIndex(String name) { 128 if (inputsIndexes == null) { 129 String[] names = getInputNames(interpreterHandle); 130 inputsIndexes = new HashMap<>(); 131 if (names != null) { 132 for (int i = 0; i < names.length; ++i) { 133 inputsIndexes.put(names[i], i); 134 } 135 } 136 } 137 if (inputsIndexes.containsKey(name)) { 138 return inputsIndexes.get(name); 139 } else { 140 throw new IllegalArgumentException( 141 String.format( 142 "%s is not a valid name for any input. The indexes of the inputs are %s", 143 name, inputsIndexes.toString())); 144 } 145 } 146 147 /** Gets index of an output given its name. */ 148 int getOutputIndex(String name) { 149 if (outputsIndexes == null) { 150 String[] names = getOutputNames(interpreterHandle); 151 outputsIndexes = new HashMap<>(); 152 if (names != null) { 153 for (int i = 0; i < names.length; ++i) { 154 outputsIndexes.put(names[i], i); 155 } 156 } 157 } 158 if (outputsIndexes.containsKey(name)) { 159 return outputsIndexes.get(name); 160 } else { 161 throw new IllegalArgumentException( 162 String.format( 163 "%s is not a valid name for any output. The indexes of the outputs are %s", 164 name, outputsIndexes.toString())); 165 } 166 } 167 168 static int numElements(int[] shape) { 169 if (shape == null) { 170 return 0; 171 } 172 int n = 1; 173 for (int i = 0; i < shape.length; i++) { 174 n *= shape[i]; 175 } 176 return n; 177 } 178 179 static boolean isNonEmptyArray(Object o) { 180 return (o != null && o.getClass().isArray() && Array.getLength(o) != 0); 181 } 182 183 /** Returns the type of the data. */ 184 static DataType dataTypeOf(Object o) { 185 if (o != null) { 186 Class<?> c = o.getClass(); 187 while (c.isArray()) { 188 c = c.getComponentType(); 189 } 190 if (float.class.equals(c)) { 191 return DataType.FLOAT32; 192 } else if (int.class.equals(c)) { 193 return DataType.INT32; 194 } else if (byte.class.equals(c)) { 195 return DataType.UINT8; 196 } else if (long.class.equals(c)) { 197 return DataType.INT64; 198 } else if (ByteBuffer.class.isInstance(o)) { 199 return DataType.BYTEBUFFER; 200 } 201 } 202 throw new IllegalArgumentException("cannot resolve DataType of " + o.getClass().getName()); 203 } 204 205 /** Returns the shape of an object as an int array. */ 206 static int[] shapeOf(Object o) { 207 int size = numDimensions(o); 208 int[] dimensions = new int[size]; 209 fillShape(o, 0, dimensions); 210 return dimensions; 211 } 212 213 static int numDimensions(Object o) { 214 if (o == null || !o.getClass().isArray()) { 215 return 0; 216 } 217 if (Array.getLength(o) == 0) { 218 throw new IllegalArgumentException("array lengths cannot be 0."); 219 } 220 return 1 + numDimensions(Array.get(o, 0)); 221 } 222 223 static void fillShape(Object o, int dim, int[] shape) { 224 if (shape == null || dim == shape.length) { 225 return; 226 } 227 final int len = Array.getLength(o); 228 if (shape[dim] == 0) { 229 shape[dim] = len; 230 } else if (shape[dim] != len) { 231 throw new IllegalArgumentException( 232 String.format("mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim)); 233 } 234 for (int i = 0; i < len; ++i) { 235 fillShape(Array.get(o, i), dim + 1, shape); 236 } 237 } 238 239 private static final int ERROR_BUFFER_SIZE = 512; 240 241 private long errorHandle; 242 243 private long interpreterHandle; 244 245 private long modelHandle; 246 247 private int inputSize; 248 249 private MappedByteBuffer modelByteBuffer; 250 251 private Map<String, Integer> inputsIndexes; 252 253 private Map<String, Integer> outputsIndexes; 254 255 private static native String[] getInputNames(long interpreterHandle); 256 257 private static native String[] getOutputNames(long interpreterHandle); 258 259 private static native void useNNAPI(long interpreterHandle, boolean state); 260 261 private static native long createErrorReporter(int size); 262 263 private static native long createModel(String modelPathOrBuffer, long errorHandle); 264 265 private static native long createModelWithBuffer(MappedByteBuffer modelBuffer, long errorHandle); 266 267 private static native long createInterpreter(long modelHandle, long errorHandle); 268 269 private static native void delete(long errorHandle, long modelHandle, long interpreterHandle); 270 271 private static native int[] getInputDims(long interpreterHandle, int inputIdx, int numBytes); 272 273 static { 274 TensorFlowLite.init(); 275 } 276 } 277