Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/math_ops.cc.
     17 
     18 #define EIGEN_USE_THREADS
     19 
     20 #include "tensorflow/core/framework/op.h"
     21 #include "tensorflow/core/framework/op_kernel.h"
     22 #include "tensorflow/core/framework/type_traits.h"
     23 #include "tensorflow/core/framework/types.h"
     24 #include "tensorflow/core/kernels/cwise_ops.h"
     25 #include "tensorflow/core/kernels/meta_support.h"
     26 #include "tensorflow/core/kernels/quantization_utils.h"
     27 #include "tensorflow/core/lib/core/errors.h"
     28 
     29 namespace {
     30 enum {
     31   QUANTIZE_MODE_MIN_COMBINED,
     32   QUANTIZE_MODE_MIN_FIRST,
     33   QUANTIZE_MODE_SCALED,
     34 };
     35 enum {
     36   // Round half away from zero: if the fraction of y is exactly 0.5, then
     37   // round(y) = y + 0.5 if y > 0
     38   // round(y) = y - 0.5 if y < 0
     39   // E.g., -5.5 gets rounded to -6, -5.4 goes to -5,
     40   // 5.4 goes to 5, and 5.5 goes to 6.
     41   ROUND_HALF_AWAY_FROM_ZERO,
     42   // Round half to even: if the fraction of y is exactly 0.5, then round(y) is
     43   // the nearest even integer to y.
     44   // E.g., 23.5 gets rounded to 24, 24.5 gets rounded to 24, while -23.5 becomes
     45   // -24, and -24.5 gets rounded to 24.
     46   ROUND_HALF_TO_EVEN,
     47 };
     48 }  // namespace
     49 
     50 namespace tensorflow {
     51 
     52 typedef Eigen::ThreadPoolDevice CPUDevice;
     53 
     54 // Quantize a tensor from float to T, with user-specified min_range and
     55 // max_range.
     56 // TODO(xbing): Add a new QuantizeOp just taking scale,
     57 //              rather than min_range and max_range.
     58 template <typename Device, typename T>
     59 class QuantizeV2Op : public OpKernel {
     60  public:
     61   explicit QuantizeV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) {
     62     half_range_ =
     63         !std::is_signed<T>::value
     64             ? 0.0f
     65             : (static_cast<double>(std::numeric_limits<T>::max()) -
     66                static_cast<double>(std::numeric_limits<T>::min()) + 1) /
     67                   2.0f;
     68     string mode_string;
     69     OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_string));
     70     OP_REQUIRES(ctx,
     71                 (mode_string == "MIN_COMBINED" || mode_string == "MIN_FIRST" ||
     72                  mode_string == "SCALED"),
     73                 errors::InvalidArgument("Mode string must be 'MIN_COMBINED',"
     74                                         " 'MIN_FIRST', or 'SCALED', is '" +
     75                                         mode_string + "'"));
     76     if (mode_string == "MIN_COMBINED") {
     77       mode_ = QUANTIZE_MODE_MIN_COMBINED;
     78     } else if (mode_string == "MIN_FIRST") {
     79       mode_ = QUANTIZE_MODE_MIN_FIRST;
     80     } else if (mode_string == "SCALED") {
     81       mode_ = QUANTIZE_MODE_SCALED;
     82     }
     83 
     84     string round_mode_string;
     85     OP_REQUIRES_OK(ctx, ctx->GetAttr("round_mode", &round_mode_string));
     86     OP_REQUIRES(ctx,
     87                 (round_mode_string == "HALF_AWAY_FROM_ZERO" ||
     88                  round_mode_string == "HALF_TO_EVEN"),
     89                 errors::InvalidArgument("Round mode string must be "
     90                                         "'HALF_AWAY_FROM_ZERO' or "
     91                                         "'HALF_TO_EVEN', is '" +
     92                                         round_mode_string + "'"));
     93     if (round_mode_string == "HALF_AWAY_FROM_ZERO") {
     94       round_mode_ = ROUND_HALF_AWAY_FROM_ZERO;
     95     } else if (round_mode_string == "HALF_TO_EVEN") {
     96       OP_REQUIRES(ctx, mode_string == "SCALED",
     97                   errors::InvalidArgument("Round mode 'HALF_TO_EVEN' "
     98                                           "only supported for mode 'SCALED', "
     99                                           "but mode is '" +
    100                                           mode_string + "'."));
    101       round_mode_ = ROUND_HALF_TO_EVEN;
    102     }
    103   }
    104 
    105   void Compute(OpKernelContext* ctx) override {
    106     const Tensor& input = ctx->input(0);
    107     const float input_min_range = ctx->input(1).flat<float>()(0);
    108     const float input_max_range = ctx->input(2).flat<float>()(0);
    109 
    110     float min_range;
    111     float max_range;
    112     OP_REQUIRES(ctx, !(input_max_range < input_min_range),
    113                 errors::InvalidArgument(
    114                     "input_max_range must be larger than input_min_range."));
    115 
    116     // When the minimum and maximum ranges are too close together, nudge them
    117     // apart by a small value so that they are slightly different. This helps
    118     // us avoid creating ill-formed buffers where all quantized values map to
    119     // the same float number. These kinds of buffers cause problems for
    120     // downstream ops when they need to do calculations on them.
    121     // We pick the value by making sure that zero is not more than 100x the
    122     // overall range from the maximum, so that the value can be easily
    123     // represented when we promote the quantized value to a higher
    124     // intermediate bit depth, since that's a common requirement.
    125     min_range = std::min(0.0f, input_min_range);
    126     const float epsilon = std::max(1.0f, std::max(fabsf(input_min_range),
    127                                                   fabsf(input_max_range))) /
    128                           100.0f;
    129     max_range = std::max(input_max_range, min_range + epsilon);
    130     max_range = std::max(0.0f, max_range);
    131 
    132     Tensor* output = nullptr;
    133     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
    134     if (mode_ == QUANTIZE_MODE_MIN_COMBINED) {
    135       const float scale_factor =
    136           (static_cast<double>(std::numeric_limits<T>::max()) -
    137            static_cast<double>(std::numeric_limits<T>::min())) /
    138           (max_range - min_range);
    139 
    140       // Quantize:
    141       // Make input in range of [min_range, max_range], then
    142       // subtract min_range to be in range of [0, max_range - min_range]
    143       // Divide by (max_range - min_range) to get to [0, 1.0]
    144       // Multiply by range of T, after that shift left 1/2 range of T if
    145       // T is signed.
    146       // Note that the number is rounded before the cast. Rounding follows the
    147       // semantic of std::round, which implements "round-half-away-zero",
    148       // e.g., -5.5 gets rounded to -6, -5.4 goes to -5, 5.4 goes to 5,
    149       // and 5.5 goes to 6.
    150       typename TTypes<T>::Vec o = output->template flat<T>();
    151       bool is_signed = std::is_signed<T>::value;
    152       if (is_signed) {
    153         // The slow path.
    154         // TODO(xbing,yonghui): Speedup this path as well.
    155         o.device(ctx->template eigen_device<Device>()) =
    156             ((input.flat<float>().cwiseMin(max_range).cwiseMax(min_range) -
    157               min_range) *
    158                  scale_factor -
    159              half_range_)
    160                 .round()
    161                 .template cast<T>();
    162       } else {
    163         // The fast path that avoids unaryExpr
    164         // According to the micro-benchmark, adding device here doesn't help.
    165         o = ((input.flat<float>().cwiseMin(max_range).cwiseMax(min_range) -
    166               min_range) *
    167                  scale_factor +
    168              0.5f)
    169                 .template cast<T>();
    170       }
    171     } else if (mode_ == QUANTIZE_MODE_MIN_FIRST) {
    172       if (meta::IsSupportedAndEnabled() && std::is_same<T, quint8>()) {
    173         TTypes<const float>::Vec input_array = input.flat<float>();
    174 
    175         meta::Quantize(ctx, input_array.data(), input_array.size(), min_range,
    176                        max_range, output->flat<quint8>().data());
    177       } else {
    178         FloatTensorToQuantizedInPlaceUsingEigen<T>(
    179             ctx->template eigen_device<Device>(), input, min_range, max_range,
    180             output);
    181       }
    182     } else if (mode_ == QUANTIZE_MODE_SCALED) {
    183       // The quantization logic for mode SCALED matches that of
    184       // QuantizeAndDequantizeV2 and QuantizeAndDequantizeV3.
    185       typename TTypes<T>::Vec o = output->template flat<T>();
    186       static constexpr int num_bits = sizeof(T) * 8;
    187       const float max_abs = std::max(std::abs(min_range), std::abs(max_range));
    188       const bool is_signed = std::is_signed<T>::value;
    189       float target_range;
    190       if (is_signed) {
    191         max_range = max_abs;
    192         min_range = -max_abs;
    193         // If it is signed, we try to keep 0.0 being 0 and drop one bucket. For
    194         // example, if it is 8 bits, we have the range [-127, 127]. So for input
    195         // range of [-x, x], the scale should be 254/(2*x).
    196         target_range = static_cast<float>((uint64_t{1} << (num_bits - 1)) - 1);
    197       } else {
    198         max_range = max_abs;
    199         min_range = 0.0;
    200         // If it is unsigned and num_bits == 8, the range with 8 bits is [0,
    201         // 255].  If the input range is [0, x], then the scale is x/255 instead
    202         // of 254 as in the case above.
    203         target_range = static_cast<float>((uint64_t{1} << num_bits) - 1);
    204       }
    205       const float scale_factor = target_range / max_abs;
    206       if (round_mode_ == ROUND_HALF_TO_EVEN) {
    207         // scalar_round_op_google implements "round-half-to-even".
    208         o.device(ctx->template eigen_device<Device>()) =
    209             (input.flat<float>().cwiseMin(max_range).cwiseMax(min_range) *
    210              scale_factor)
    211                 .unaryExpr(Eigen::internal::scalar_round_op_google<float>())
    212                 .template cast<T>();
    213       } else if (round_mode_ == ROUND_HALF_AWAY_FROM_ZERO) {
    214         // scalar_round_op implements "round-half-away-from-zero".
    215         o.device(ctx->template eigen_device<Device>()) =
    216             (input.flat<float>().cwiseMin(max_range).cwiseMax(min_range) *
    217              scale_factor)
    218                 .unaryExpr(Eigen::internal::scalar_round_op<float>())
    219                 .template cast<T>();
    220       }
    221     }
    222 
    223     Tensor* output_min_tensor = nullptr;
    224     OP_REQUIRES_OK(ctx, ctx->allocate_output(1, {}, &output_min_tensor));
    225     output_min_tensor->flat<float>()(0) = min_range;
    226 
    227     Tensor* output_max_tensor = nullptr;
    228     OP_REQUIRES_OK(ctx, ctx->allocate_output(2, {}, &output_max_tensor));
    229     output_max_tensor->flat<float>()(0) = max_range;
    230   }
    231 
    232  private:
    233   float half_range_;
    234   int mode_;
    235   int round_mode_;
    236 };
    237 
    238 REGISTER_KERNEL_BUILDER(
    239     Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint<quint8>("T"),
    240     QuantizeV2Op<CPUDevice, quint8>);
    241 REGISTER_KERNEL_BUILDER(
    242     Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint<qint8>("T"),
    243     QuantizeV2Op<CPUDevice, qint8>);
    244 REGISTER_KERNEL_BUILDER(
    245     Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint<quint16>("T"),
    246     QuantizeV2Op<CPUDevice, quint16>);
    247 REGISTER_KERNEL_BUILDER(
    248     Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint<qint16>("T"),
    249     QuantizeV2Op<CPUDevice, qint16>);
    250 REGISTER_KERNEL_BUILDER(
    251     Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint<qint32>("T"),
    252     QuantizeV2Op<CPUDevice, qint32>);
    253 }  // namespace tensorflow
    254