Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
     17 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
     18 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
     19 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
     20 #include "tensorflow/core/platform/macros.h"
     21 
     22 namespace tensorflow {
     23 namespace {
     24 
     25 class QuantizeAndDequantizeOp : public XlaOpKernel {
     26  public:
     27   explicit QuantizeAndDequantizeOp(OpKernelConstruction* ctx)
     28       : XlaOpKernel(ctx) {
     29     OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_));
     30     OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));
     31     OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
     32     OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63),
     33                 errors::InvalidArgument("num_bits is out of range: ", num_bits_,
     34                                         " with signed_input_ ", signed_input_));
     35   }
     36 
     37   void Compile(XlaOpKernelContext* ctx) override {
     38     xla::ComputationDataHandle input = ctx->Input(0);
     39     const DataType data_type = ctx->input_type(0);
     40 
     41     // Comments taken from semantics description at
     42     // https://www.tensorflow.org/versions/r1.0/api_docs/cc/class/tensorflow/ops/quantize-and-dequantize
     43     //
     44     // ... we find m such that
     45     //
     46     // m = max(abs(input_min), abs(input_max)) if range_given is true,
     47     // m = max(abs(min_elem(input)),
     48     //         abs(max_elem(input))) otherwise.
     49     xla::ComputationBuilder* b = ctx->builder();
     50     xla::ComputationDataHandle input_min, input_max;
     51     if (range_given_) {
     52       double input_min_value, input_max_value;
     53       OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(1, &input_min_value));
     54       OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(2, &input_max_value));
     55       input_min = XlaHelpers::FloatLiteral(b, data_type, input_min_value);
     56       input_max = XlaHelpers::FloatLiteral(b, data_type, input_max_value);
     57     } else {
     58       const xla::Computation* fmax = ctx->GetOrCreateMax(data_type);
     59       const xla::Computation* fmin = ctx->GetOrCreateMin(data_type);
     60       input_min =
     61           b->ReduceAll(input, XlaHelpers::MaxValue(b, data_type), *fmin);
     62       input_max =
     63           b->ReduceAll(input, XlaHelpers::MinValue(b, data_type), *fmax);
     64     }
     65     xla::ComputationDataHandle m = b->Max(b->Abs(input_min), b->Abs(input_max));
     66 
     67     // Next, we choose our fixed-point quantization buckets, [min_fixed,
     68     // max_fixed]. If signed_input is true, this is
     69     //
     70     // [min_fixed, max_fixed ] = [-((1 << (num_bits - 1)) - 1),
     71     //                             (1 << (num_bits - 1)) - 1].
     72     //
     73     // Otherwise, if signed_input is false, the fixed-point range is
     74     //
     75     // [min_fixed, max_fixed] = [0, (1 << num_bits) - 1].
     76     int64 min_fixed, max_fixed;
     77     if (signed_input_) {
     78       min_fixed = -((1LL << (num_bits_ - 1)) - 1);
     79       max_fixed = (1LL << (num_bits_ - 1)) - 1;
     80     } else {
     81       min_fixed = 0;
     82       max_fixed = (1LL << num_bits_) - 1;
     83     }
     84 
     85     // From this we compute our scaling factor, s:
     86     //
     87     // s = (max_fixed - min_fixed) / (2 * m).
     88     xla::ComputationDataHandle s =
     89         b->Div(XlaHelpers::FloatLiteral(b, data_type, max_fixed - min_fixed),
     90                b->Mul(XlaHelpers::FloatLiteral(b, data_type, 2.0), m));
     91 
     92     // Now we can quantize and dequantize the elements of our tensor. An element
     93     // e is transformed into e':
     94     //
     95     // e' = (e * s).round_to_nearest() / s.
     96     xla::ComputationDataHandle result = b->Div(b->Round(b->Mul(input, s)), s);
     97 
     98     ctx->SetOutput(0, result);
     99   }
    100 
    101   int64 num_bits_;
    102   bool signed_input_;
    103   bool range_given_;
    104 };
    105 
    106 REGISTER_XLA_OP(Name("QuantizeAndDequantizeV2"), QuantizeAndDequantizeOp);
    107 
    108 }  // namespace
    109 }  // namespace tensorflow
    110