# How to Quantize Neural Networks with TensorFlow

When modern neural networks were being developed, the biggest challenge was
getting them to work at all! That meant that accuracy and speed during training
were the top priorities. Using floating-point arithmetic was the easiest way to
preserve accuracy, and GPUs were well-equipped to accelerate those calculations,
so it's natural that not much attention was paid to other numerical formats.

These days, we actually have a lot of models being deployed in commercial
applications. The computation demands of training grow with the number of
researchers, but the cycles needed for inference expand in proportion to users.
That means pure inference efficiency has become a burning issue for a lot of
teams.

That is where quantization comes in. It's an umbrella term that covers a lot of
different techniques to store numbers and perform calculations on them in more
compact formats than 32-bit floating point. I am going to focus on eight-bit
fixed point, for reasons I'll go into more detail on later.

[TOC]

## Why Does Quantization Work?

Training neural networks is done by applying many tiny nudges to the weights,
and these small increments typically need floating-point precision to work
(though there are research efforts to use quantized representations here too).

Taking a pre-trained model and running inference is very different. One of the
magical qualities of deep networks is that they tend to cope very well with high
levels of noise in their inputs. If you think about recognizing an object in a
photo you've just taken, the network has to ignore all the CCD noise, lighting
changes, and other non-essential differences between it and the training
examples it's seen before, and focus on the important similarities instead.
This
ability means that they seem to treat low-precision calculations as just another
source of noise, and still produce accurate results even with numerical formats
that hold less information.

## Why Quantize?

Neural network models can take up a lot of space on disk, with the original
AlexNet being over 200 MB in float format, for example. Almost all of that size
is taken up with the weights for the neural connections, since there are often
many millions of these in a single model. Because they're all slightly different
floating-point numbers, simple compression formats like zip don't compress them
well. They are arranged in large layers though, and within each layer the
weights tend to be normally distributed within a certain range, for example -3.0
to 6.0.

The simplest motivation for quantization is to shrink file sizes by storing the
min and max for each layer, and then compressing each float value to an
eight-bit integer representing the closest real number in a linear set of 256
within the range. For example, with the -3.0 to 6.0 range, a 0 byte would
represent -3.0, a 255 would stand for 6.0, and 128 would represent about 1.5.
I'll go into the exact calculations later, since there are some subtleties, but
this means you can get the benefit of a file on disk that's shrunk by 75%, and
then convert back to float after loading so that your existing floating-point
code can work without any changes.

Another reason to quantize is to reduce the computational resources you need to
do the inference calculations, by running them entirely with eight-bit inputs
and outputs. This is a lot more difficult since it requires changes everywhere
you do calculations, but offers a lot of potential rewards. Fetching eight-bit
values only requires 25% of the memory bandwidth of floats, so you'll make much
better use of caches and avoid bottlenecking on RAM access.
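The min/max compression scheme described earlier can be sketched in a few lines of Python (hypothetical helper names, not TensorFlow's actual code):

```python
# Sketch of the per-layer min/max quantization scheme described above.
# Hypothetical helpers -- not TensorFlow's actual implementation.

def quantize(value, min_val, max_val):
    """Map a float in [min_val, max_val] to the nearest of 256 levels."""
    scale = (max_val - min_val) / 255.0
    return int(round((value - min_val) / scale))

def dequantize(code, min_val, max_val):
    """Recover the float value that an eight-bit code represents."""
    scale = (max_val - min_val) / 255.0
    return min_val + code * scale

# The -3.0 to 6.0 example from the text:
print(quantize(-3.0, -3.0, 6.0))   # 0
print(quantize(6.0, -3.0, 6.0))    # 255
print(dequantize(128, -3.0, 6.0))  # about 1.5
```

Storing one min/max pair per layer plus a single byte per weight is what delivers the roughly 75% file-size reduction.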
You can also
typically use SIMD operations that do many more operations per clock cycle. In
some cases you'll have a DSP chip available that can accelerate eight-bit
calculations too, which can offer a lot of advantages.

Moving calculations over to eight bit will help you run your models faster, and
use less power (which is especially important on mobile devices). It also opens
the door to a lot of embedded systems that can't run floating-point code
efficiently, so it can enable a lot of applications in the IoT world.

## Why Not Train in Lower Precision Directly?

There have been some experiments training at lower bit depths, but the results
seem to indicate that you need more than eight bits to handle the
back-propagation and gradients. That makes implementing the training more
complicated, and so starting with inference made sense. We also already have a
lot of float models that we use and know well, so being able to convert them
directly is very convenient.

## How Can You Quantize Your Models?

TensorFlow has production-grade support for eight-bit calculations built in. It
also has a process for converting many models trained in floating point over to
equivalent graphs using quantized calculations for inference.
For example,
here's how you can translate the latest GoogLeNet model into a version that uses
eight-bit computations:

```sh
curl -L "https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz" |
tar -C tensorflow/examples/label_image/data -xz
bazel build tensorflow/tools/graph_transforms:transform_graph
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
  --in_graph=tensorflow/examples/label_image/data/inception_v3_2016_08_28_frozen.pb \
  --out_graph=/tmp/quantized_graph.pb \
  --inputs=input \
  --outputs=InceptionV3/Predictions/Reshape_1 \
  --transforms='add_default_attributes strip_unused_nodes(type=float, shape="1,299,299,3")
    remove_nodes(op=Identity, op=CheckNumerics) fold_constants(ignore_errors=true)
    fold_batch_norms fold_old_batch_norms quantize_weights quantize_nodes
    strip_unused_nodes sort_by_execution_order'
```

This will produce a new model that runs the same operations as the original, but
with eight-bit calculations internally, and all weights quantized as well. If
you look at the file size, you'll see it's about a quarter of the original
(23 MB versus 91 MB). You can still run this model using exactly the same inputs
and outputs though, and you should get equivalent results. Here's an example:

```sh
bazel build tensorflow/examples/label_image:label_image
bazel-bin/tensorflow/examples/label_image/label_image \
  --graph=/tmp/quantized_graph.pb
```

You'll see that this runs the newly-quantized graph, and outputs a very similar
answer to the original.

You can run the same process on your own models saved out as GraphDefs, with the
input and output names adapted to those your network requires. I recommend that
you run them through the freeze_graph script first, to convert checkpoints into
constants stored in the file.

## How Does the Quantization Process Work?
We've implemented quantization by writing equivalent eight-bit versions of
operations that are commonly used during inference. These include convolution,
matrix multiplication, activation functions, pooling operations, and
concatenation. The conversion script first replaces all the individual ops it
knows about with quantized equivalents. These are small sub-graphs that have
conversion functions before and after to move the data between float and
eight-bit. Below is an example of what they look like. First here's the original
Relu operation, with float inputs and outputs:

![Relu Diagram](https://www.tensorflow.org/images/quantization0.png)

Then, this is the equivalent converted subgraph, still with float inputs and
outputs, but with internal conversions so the calculations are done in eight
bit.

![Converted Diagram](https://www.tensorflow.org/images/quantization1.png)

The min and max operations actually look at the values in the input float
tensor, and then feed them into the Quantize operation that converts the
tensor into eight bits. There are more details on how the quantized
representation works later on.

Once the individual operations have been converted, the next stage is to remove
unnecessary conversions to and from float. If there are consecutive sequences of
operations that all have quantized equivalents, then there will be a lot of
adjacent Dequantize/Quantize ops. This stage spots that pattern, recognizes that
they cancel each other out, and removes them, like this:

![Stripping Diagram](https://www.tensorflow.org/images/quantization2.png)

Applied on a large scale to models where all of the operations have quantized
equivalents, this gives a graph where all of the tensor calculations are done in
eight bit, without having to convert to float.

## What Representation is Used for Quantized Tensors?
We approach converting floating-point arrays of numbers into eight-bit
representations as a compression problem. We know that the weights and
activation tensors in trained neural network models tend to have values that are
distributed across comparatively small ranges (for example you might have -15 to
+15 for weights, and -500 to 1000 for activations on an image model, though the
exact numbers will vary). We also know from experiment that neural nets tend to
be very robust in the face of noise, and so the noise-like error produced by
quantizing down to a small set of values will not hurt the precision of the
overall results very much. We also want to pick a representation that's easy to
perform calculations on, especially the large matrix multiplications that form
the bulk of the work that's needed to run a model.

These requirements led us to pick a representation that has two floats to store
the overall minimum and maximum values represented by the lowest and highest
quantized value. Each entry in the quantized array represents a float value in
that range, distributed linearly between the minimum and maximum. For example,
with minimum = -10.0, maximum = 30.0, and an eight-bit array, here's what the
quantized values represent:

```
Quantized | Float
--------- | -----
0         | -10.0
128       | 10.0
255       | 30.0
```

The advantages of this format are that it can represent arbitrary magnitudes of
ranges, the ranges don't have to be symmetrical, it can represent both signed
and unsigned values, and the linear spread makes doing multiplications
straightforward. There are alternatives like
[Song Han's code books](http://arxiv.org/pdf/1510.00149.pdf)
that can use lower bit depths by non-linearly distributing the float values
across the representation, but these tend to be more expensive to calculate on.
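To illustrate why the linear spread keeps multiplication cheap, here's a hedged sketch with made-up scales and zero points (a min/max pair determines an equivalent scale and zero point; real implementations like gemmlowp add many optimizations on top of this). Writing each real value as `scale * (quantized - zero_point)` lets a whole dot product accumulate in integer arithmetic, with a single float rescale at the end:

```python
# Sketch: integer-only dot product between two quantized tensors.
# Each real value is scale * (q - zero_point); the values below are
# illustrative, not from any real model.

def quantize(values, scale, zero_point):
    """Map floats to eight-bit codes under the affine scheme (simplified)."""
    return [max(0, min(255, int(round(v / scale)) + zero_point)) for v in values]

def quantized_dot(qa, za, sa, qb, zb, sb):
    # All the per-element work is integer math; real code accumulates
    # in int32. The float scales are applied exactly once, at the end.
    acc = sum((a - za) * (b - zb) for a, b in zip(qa, qb))
    return sa * sb * acc

a = [0.5, -1.25, 2.0]
b = [1.0, 0.75, -0.5]
sa, za = 0.05, 128  # covers roughly -6.4 to 6.35
sb, zb = 0.05, 128

qa = quantize(a, sa, za)
qb = quantize(b, sb, zb)
print(quantized_dot(qa, za, sa, qb, zb, sb))  # approximately -1.4375,
                                              # matching the float dot product
```

Because products of eight-bit values sum up quickly, the accumulator needs far more than eight bits, which is why the math ops discussed below produce 32-bit results.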
The advantage of having a strong and clear definition of the quantized format is
that it's always possible to convert back and forth from float for operations
that aren't quantization-ready, or to inspect the tensors for debugging
purposes. One implementation detail in TensorFlow that we're hoping to improve
in the future is that the minimum and maximum float values need to be passed as
separate tensors alongside the one holding the quantized values, so graphs can
get a bit dense!

The nice thing about the minimum and maximum ranges is that they can often be
pre-calculated. Weight parameters are constants known at load time, so their
ranges can also be stored as constants. We often know the ranges for inputs (for
example, images are usually RGB values in the range 0.0 to 255.0), and many
activation functions have known ranges too. This can avoid having to analyze the
outputs of an operation to determine the range, which we need to do for math ops
like convolution or matrix multiplication that produce 32-bit accumulated
results from 8-bit inputs.

## What's Next?

We've found that we can get extremely good performance on mobile and embedded
devices by using eight-bit arithmetic rather than floating point. You can see
the framework we use to optimize matrix multiplications at
[gemmlowp](https://github.com/google/gemmlowp). We still need to apply all the
lessons we've learned to the TensorFlow ops to get maximum performance on
mobile, but we're actively working on that. Right now, this quantized
implementation is a reasonably fast and accurate reference implementation that
we're hoping will enable support for eight-bit models on a wider variety of
devices. We also hope that this demonstration will encourage the community to
explore what's possible with low-precision neural networks.