Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_KERNELS_QUANTIZE_AND_DEQUANTIZE_OP_H_
     17 #define TENSORFLOW_CORE_KERNELS_QUANTIZE_AND_DEQUANTIZE_OP_H_
     18 
     19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     20 #include "tensorflow/core/framework/op_kernel.h"
     21 #include "tensorflow/core/framework/tensor_types.h"
     22 
     23 namespace tensorflow {
     24 namespace functor {
     25 
     26 template <typename Device, typename T>
     27 struct QuantizeAndDequantizeOneScaleFunctor {
     28   void operator()(const Device& d, typename TTypes<T>::ConstVec input,
     29                   bool signed_input, int num_bits, bool range_given,
     30                   Tensor* input_min_tensor, Tensor* input_max_tensor,
     31                   typename TTypes<T>::Vec out);
     32 };
     33 
     34 // The implementation below runs on both CPU and GPU.
     35 template <typename Device, typename T>
     36 struct QuantizeAndDequantizeOneScaleImpl {
     37   static void Compute(const Device& d, typename TTypes<T>::ConstVec input,
     38                       bool signed_input, int num_bits, bool range_given,
     39                       Tensor* input_min_tensor, Tensor* input_max_tensor,
     40                       typename TTypes<T>::Vec out) {
     41     T min_range;
     42     T max_range;
     43     auto input_min = input_min_tensor->scalar<T>();
     44     auto input_max = input_max_tensor->scalar<T>();
     45     if (!range_given) {
     46       input_min.device(d) = input.minimum();
     47       input_max.device(d) = input.maximum();
     48     }
     49     d.memcpyDeviceToHost(&min_range, input_min.data(), sizeof(T));
     50     d.memcpyDeviceToHost(&max_range, input_max.data(), sizeof(T));
     51 
     52     // Make sure the range is symmetric for signed quantization, or start from
     53     // 0 for unsigned quantization.
     54     max_range = std::max(std::abs(max_range), std::abs(min_range));
     55 
     56     // If both min and max are 0, then the output should be just 0.
     57     if (max_range == 0) {
     58       out.device(d) = input.constant(T(0));
     59       return;
     60     }
     61 
     62     if (signed_input) {
     63       min_range = -max_range;
     64 
     65       // If it is signed, we try to keep 0.0 being 0 and drop one bucket. For
     66       // example, if it is 8 bits, we have the range [-127, 127]. So for input
     67       // range of [-x, x], the scale should be 254/(2*x).
     68       T scale = static_cast<T>((uint64_t{1} << (num_bits - 1)) - 1) / max_range;
     69       T inverse_scale = T(1.0) / scale;
     70       if (range_given) {
     71         out.device(d) =
     72             ((input.cwiseMin(max_range).cwiseMax(min_range) - min_range) *
     73                  scale +
     74              T(0.5))
     75                     .floor() *
     76                 inverse_scale +
     77             min_range;
     78       } else {
     79         // No need to compare with min and max as they are measured from the
     80         // tensor.
     81         out.device(d) =
     82             ((input - min_range) * scale + T(0.5)).floor() * inverse_scale +
     83             min_range;
     84       }
     85     } else {
     86       min_range = 0;
     87       // If it is unsigned and num_bits == 8, the range with 8 bits is [0, 255].
     88       // If the input range is [0, x], then the scale is x/255 instead of 254 as
     89       // in the case above.
     90       T scale = static_cast<T>((uint64_t{1} << num_bits) - 1) / max_range;
     91       T inverse_scale = 1.0 / scale;
     92       if (range_given) {
     93         out.device(d) =
     94             ((input.cwiseMin(max_range).cwiseMax(min_range)) * scale + T(0.5))
     95                 .floor() *
     96             inverse_scale;
     97       } else {
     98         // No need to compare with min and max as they are measured from the
     99         // tensor.
    100         out.device(d) = (input * scale + T(0.5)).floor() * inverse_scale;
    101       }
    102     }
    103   }
    104 };
    105 
    106 }  // end of namespace functor
    107 }  // end of namespace tensorflow
    108 
    109 #endif  // TENSORFLOW_CORE_KERNELS_QUANTIZE_AND_DEQUANTIZE_OP_H_
    110