/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/errors.h"

namespace {
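// How the incoming tensor was quantized; parsed from the "mode" attr in the
// op constructor below.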
enum {
  QUANTIZE_MODE_MIN_COMBINED,
  QUANTIZE_MODE_MIN_FIRST,
  QUANTIZE_MODE_SCALED,
};
}  // namespace

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Device, typename T>
class DequantizeOp : public OpKernel {
 public:
  explicit DequantizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
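    // half_range_ shifts signed quantized values into an unsigned range
    // before scaling: 0 for unsigned T, or (max - min + 1) / 2 for signed T
    // (e.g. 128.0f for qint8).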
    half_range_ = !std::is_signed<T>::value
                      ? 0.0f
                      : (static_cast<float>(std::numeric_limits<T>::max()) -
                         std::numeric_limits<T>::min() + 1) /
                            2.0f;
    string mode_string;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_string));
    OP_REQUIRES(ctx,
                (mode_string == "MIN_COMBINED" || mode_string == "MIN_FIRST" ||
                 mode_string == "SCALED"),
                errors::InvalidArgument("Mode string must be 'MIN_COMBINED',"
                                        " 'MIN_FIRST', or 'SCALED', is '" +
                                        mode_string + "'"));
    if (mode_string == "MIN_COMBINED") {
      mode_ = QUANTIZE_MODE_MIN_COMBINED;
    } else if (mode_string == "MIN_FIRST") {
      mode_ = QUANTIZE_MODE_MIN_FIRST;
    } else if (mode_string == "SCALED") {
      mode_ = QUANTIZE_MODE_SCALED;
    }
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);
    const float min_range = ctx->input(1).flat<float>()(0);
    const float max_range = ctx->input(2).flat<float>()(0);

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
    if (mode_ == QUANTIZE_MODE_MIN_COMBINED) {
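      // MIN_COMBINED: out = (in + half_range_) * scale_factor + min_range,
      // where scale_factor spreads [min_range, max_range] across the full
      // range of T. For example, for qint8 with range [-1.0f, 1.0f]:
      // scale_factor = 2/255, so -128 -> -1.0f and 127 -> 1.0f.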
      const float scale_factor =
          (max_range - min_range) /
          (static_cast<float>(std::numeric_limits<T>::max()) -
           std::numeric_limits<T>::min());

      float* out_ptr = output->flat<float>().data();
      const T* in_ptr = input.flat<T>().data();

      const int64 num_elements = input.NumElements();
      for (int64 i = 0; i < num_elements; ++i) {
        out_ptr[i] =
            ((static_cast<int>(in_ptr[i]) + half_range_) * scale_factor) +
            min_range;
      }
    } else if (mode_ == QUANTIZE_MODE_MIN_FIRST) {
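      // MIN_FIRST: use the specialized meta kernels when they are supported
      // and enabled and the input is quint8; otherwise fall back to the
      // generic Eigen implementation.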
      if (meta::IsSupportedAndEnabled() && std::is_same<T, quint8>()) {
        auto input_ui8_array = input.flat<quint8>();
        meta::Dequantize(ctx, input_ui8_array.data(), input_ui8_array.size(),
                         min_range, max_range, output->flat<float>().data());
      } else {
        QuantizedTensorToFloatInPlaceUsingEigen<T>(
            ctx->template eigen_device<Device>(), input, min_range, max_range,
            output);
      }
    } else if (mode_ == QUANTIZE_MODE_SCALED) {
      // The quantization logic for mode SCALED matches that of
      // QuantizeAndDequantizeV2 and QuantizeAndDequantizeV3.
      static constexpr int num_bits = sizeof(T) * 8;
      const float max_abs = std::max(std::abs(min_range), std::abs(max_range));
      bool is_signed = std::is_signed<T>::value;
      // If T is signed, we keep 0.0 at quantized 0 and drop one bucket. For
      // example, with 8 bits the representable range is [-127, 127], so an
      // input range of [-x, x] gives a dequantization scale of x/127.
      //
      // If T is unsigned and num_bits == 8, the quantized range is [0, 255],
      // so an input range of [0, x] gives a scale of x/255 instead of x/127
      // as in the signed case above.
      const int target_bits = is_signed ? (num_bits - 1) : num_bits;
      const float target_range =
          static_cast<float>((uint64_t{1} << target_bits) - 1);
      const float scale_factor = max_abs / target_range;
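      // For example, for qint8 with range [-6.0f, 6.0f]: target_bits = 7,
      // target_range = 127, and scale_factor = 6/127, so quantized 127
      // dequantizes to 6.0f.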
      float* out_ptr = output->flat<float>().data();
      const T* in_ptr = input.flat<T>().data();

      const int64 num_elements = input.NumElements();
      for (int64 i = 0; i < num_elements; ++i) {
        out_ptr[i] = static_cast<int>(in_ptr[i]) * scale_factor;
      }
    }
  }

 private:
  float half_range_;
  int mode_;
};

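// Register the CPU kernel for each quantized input type supported by the
// Dequantize op.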
REGISTER_KERNEL_BUILDER(
    Name("Dequantize").Device(DEVICE_CPU).TypeConstraint<quint8>("T"),
    DequantizeOp<CPUDevice, quint8>);
REGISTER_KERNEL_BUILDER(
    Name("Dequantize").Device(DEVICE_CPU).TypeConstraint<qint8>("T"),
    DequantizeOp<CPUDevice, qint8>);
REGISTER_KERNEL_BUILDER(
    Name("Dequantize").Device(DEVICE_CPU).TypeConstraint<quint16>("T"),
    DequantizeOp<CPUDevice, quint16>);
REGISTER_KERNEL_BUILDER(
    Name("Dequantize").Device(DEVICE_CPU).TypeConstraint<qint16>("T"),
    DequantizeOp<CPUDevice, qint16>);

REGISTER_KERNEL_BUILDER(
    Name("Dequantize").Device(DEVICE_CPU).TypeConstraint<qint32>("T"),
    DequantizeOp<CPUDevice, qint32>);

}  // namespace tensorflow