/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/quantize_and_dequantize_op.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Simulates quantization precision loss in a float tensor by:
// 1. Quantizing the tensor to fixed-point numbers, which should match the
//    target quantization method when it is used in inference.
// 2. Dequantizing it back to floating-point numbers for the following ops,
//    most likely matmul.
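//
// A rough sketch of the one-scale round trip (the exact scale and rounding
// rules live in QuantizeAndDequantizeOneScaleImpl in the header), assuming
// signed_input = true, num_bits = 8, and a given range of [-1, 1]:
//   scale  = 127 / 1.0            // 2^(num_bits - 1) - 1 steps per side
//   q      = round(0.7 * scale)   // quantize: q = 89
//   output = q / scale            // dequantize: ~0.7008, not exactly 0.7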
template <typename Device, typename T>
class QuantizeAndDequantizeV2Op : public OpKernel {
 public:
  explicit QuantizeAndDequantizeV2Op(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
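    // Bound num_bits so that the quantized range stays representable in
    // 64-bit arithmetic in the functor implementation.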
    OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63),
                errors::InvalidArgument("num_bits is out of range: ", num_bits_,
                                        " with signed_input_ ", signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));

    Tensor input_min_tensor;
    Tensor input_max_tensor;
    if (range_given_) {
      input_min_tensor = ctx->input(1);
      input_max_tensor = ctx->input(2);
      auto min_val = input_min_tensor.scalar<T>()();
      auto max_val = input_max_tensor.scalar<T>()();
      OP_REQUIRES(ctx, min_val <= max_val,
                  errors::InvalidArgument("Invalid range: input_min ", min_val,
                                          " > input_max ", max_val));
    } else {
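      // No range was given: allocate scalar temporaries that the functor
      // fills with the observed min/max of the input.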
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape(), &input_min_tensor));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape(), &input_max_tensor));
    }

    functor::QuantizeAndDequantizeOneScaleFunctor<Device, T> f;
    f(ctx->eigen_device<Device>(), input.flat<T>(), signed_input_, num_bits_,
      range_given_, &input_min_tensor, &input_max_tensor, output->flat<T>());
  }

 private:
  bool signed_input_;
  int num_bits_;
  bool range_given_;
};

// Simulates quantization precision loss in a float tensor by:
// 1. Quantizing the tensor to fixed-point numbers, which should match the
//    target quantization method when it is used in inference.
// 2. Dequantizing it back to floating-point numbers for the following ops,
//    most likely matmul.
// Almost identical to QuantizeAndDequantizeV2Op, except that num_bits is a
// tensor.
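// Reading num_bits from the fourth input (index 3) at runtime lets callers
// vary the precision between steps (e.g. when gradually lowering bit widths
// during training) instead of fixing it as a graph attr.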
template <typename Device, typename T>
class QuantizeAndDequantizeV3Op : public OpKernel {
 public:
  explicit QuantizeAndDequantizeV3Op(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));

    const Tensor& num_bits_tensor = ctx->input(3);
    const int num_bits_val = num_bits_tensor.scalar<int32>()();

    OP_REQUIRES(
        ctx, num_bits_val > 0 && num_bits_val < (signed_input_ ? 62 : 63),
        errors::InvalidArgument("num_bits is out of range: ", num_bits_val,
                                " with signed_input_ ", signed_input_));

    Tensor input_min_tensor;
    Tensor input_max_tensor;
    if (range_given_) {
      input_min_tensor = ctx->input(1);
      input_max_tensor = ctx->input(2);
      auto min_val = input_min_tensor.scalar<T>()();
      auto max_val = input_max_tensor.scalar<T>()();
      OP_REQUIRES(ctx, min_val <= max_val,
                  errors::InvalidArgument("Invalid range: input_min ", min_val,
                                          " > input_max ", max_val));
    } else {
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape(), &input_min_tensor));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape(), &input_max_tensor));
    }

    functor::QuantizeAndDequantizeOneScaleFunctor<Device, T> f;
    f(ctx->eigen_device<Device>(), input.flat<T>(), signed_input_, num_bits_val,
      range_given_, &input_min_tensor, &input_max_tensor, output->flat<T>());
  }

 private:
  bool signed_input_;
  bool range_given_;
};

// DEPRECATED: Use QuantizeAndDequantizeV2Op.
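// Unlike V2, the range arrives through the input_min/input_max attrs, so it
// is fixed at graph-construction time instead of being supplied as runtime
// inputs.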
template <typename Device, typename T>
class QuantizeAndDequantizeOp : public OpKernel {
 public:
  explicit QuantizeAndDequantizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
    OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63),
                errors::InvalidArgument("num_bits is out of range: ", num_bits_,
                                        " with signed_input_ ", signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("input_min", &input_min_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("input_max", &input_max_));
    if (range_given_) {
      OP_REQUIRES(
          ctx, input_min_ <= input_max_,
          errors::InvalidArgument("Invalid range: input_min ", input_min_,
                                  " > input_max ", input_max_));
    }
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));

    // One global scale.
    Tensor input_min_tensor(DataTypeToEnum<T>::value, TensorShape());
    Tensor input_max_tensor(DataTypeToEnum<T>::value, TensorShape());
    // Initialize the tensors with the values in the attrs.
    input_min_tensor.template scalar<T>()() = static_cast<T>(input_min_);
    input_max_tensor.template scalar<T>()() = static_cast<T>(input_max_);

    functor::QuantizeAndDequantizeOneScaleFunctor<Device, T> functor;
    functor(ctx->eigen_device<Device>(), input.flat<T>(), signed_input_,
            num_bits_, range_given_, &input_min_tensor, &input_max_tensor,
            output->flat<T>());
  }

 private:
  bool signed_input_;
  int num_bits_;
  bool range_given_;
  float input_min_;
  float input_max_;
};

// Specialization for CPUDevice.
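// The functor indirection lets the OpKernel templates above serve both
// devices; the GPUDevice specialization is expected to be defined in the
// op's separately compiled .cu.cc file when GOOGLE_CUDA is set.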
namespace functor {
template <typename T>
struct QuantizeAndDequantizeOneScaleFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::ConstVec input,
                  const bool signed_input, const int num_bits,
                  const bool range_given, Tensor* input_min_tensor,
                  Tensor* input_max_tensor, typename TTypes<T>::Vec out) {
    QuantizeAndDequantizeOneScaleImpl<CPUDevice, T>::Compute(
        d, input, signed_input, num_bits, range_given, input_min_tensor,
        input_max_tensor, out);
  }
};
}  // namespace functor

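// Register all three op kernels for float and double on the CPU.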
#define REGISTER_CPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV2")                      \
                              .Device(DEVICE_CPU)                              \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV2Op<CPUDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV3")                      \
                              .Device(DEVICE_CPU)                              \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV3Op<CPUDevice, T>);            \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("QuantizeAndDequantize").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      QuantizeAndDequantizeOp<CPUDevice, T>);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL

#if GOOGLE_CUDA
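// The HostMemory constraints below pin the scalar range (and num_bits) inputs
// to host memory, since Compute() reads them directly on the CPU instead of
// issuing a device-to-host copy.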
#define REGISTER_GPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV2")                      \
                              .Device(DEVICE_GPU)                              \
                              .HostMemory("input_max")                         \
                              .HostMemory("input_min")                         \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV2Op<GPUDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV3")                      \
                              .Device(DEVICE_GPU)                              \
                              .HostMemory("input_max")                         \
                              .HostMemory("input_min")                         \
                              .HostMemory("num_bits")                          \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV3Op<GPUDevice, T>);            \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("QuantizeAndDequantize").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      QuantizeAndDequantizeOp<GPUDevice, T>);
TF_CALL_float(REGISTER_GPU_KERNEL);
TF_CALL_double(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA
}  // namespace tensorflow