/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {
namespace {

// Nudging the zero point ensures that the real value 0.0 maps exactly to an
// integer, which is required for e.g. zero-padding in convolutional layers.
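//
// Worked example (illustrative values): with min = -0.1, max = 1.0, and an
// 8-bit range [0, 255], the raw scale is 1.1 / 255 ~= 0.004314. The zero
// point from min is 0.1 / 0.004314 ~= 23.18, which is nudged to the integer
// 23, giving a nudged range of roughly [-0.0992, 1.0008] in which 0.0
// quantizes exactly to 23.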
void CpuNudge(const float min, const float max, const float quant_min,
              const float quant_max, float* nudged_min, float* nudged_max,
              float* scale) {
  *scale = (max - min) / (quant_max - quant_min);

  const float zero_point_from_min = quant_min - min / *scale;
  float nudged_zero_point;
  if (zero_point_from_min <= quant_min) {
    nudged_zero_point = quant_min;
  } else if (zero_point_from_min >= quant_max) {
    nudged_zero_point = quant_max;
  } else {
    nudged_zero_point = std::round(zero_point_from_min);
  }

  *nudged_min = (quant_min - nudged_zero_point) * (*scale);
  *nudged_max = (quant_max - nudged_zero_point) * (*scale);
}

// An XLA version of CpuNudge(): emits the same nudging computation as XLA
// ops, for use when min/max are runtime tensors rather than attributes.
void XlaNudge(xla::ComputationBuilder* b, const DataType data_type,
              const xla::ComputationDataHandle& min,
              const xla::ComputationDataHandle& max,
              const float quant_min_value, const float quant_max_value,
              xla::ComputationDataHandle* nudged_min,
              xla::ComputationDataHandle* nudged_max,
              xla::ComputationDataHandle* scale) {
  *scale = b->Div(b->Sub(max, min),
                  XlaHelpers::FloatLiteral(b, data_type,
                                           quant_max_value - quant_min_value));
  xla::ComputationDataHandle quant_min =
      XlaHelpers::FloatLiteral(b, data_type, quant_min_value);
  xla::ComputationDataHandle zero_point_from_min =
      b->Sub(quant_min, b->Div(min, *scale));
  xla::ComputationDataHandle quant_max =
      XlaHelpers::FloatLiteral(b, data_type, quant_max_value);
  xla::ComputationDataHandle nudged_zero_point =
      b->Select(b->Le(zero_point_from_min, quant_min), quant_min,
                b->Select(b->Ge(zero_point_from_min, quant_max), quant_max,
                          b->Round(zero_point_from_min)));
  *nudged_min = b->Mul(b->Sub(quant_min, nudged_zero_point), *scale);
  *nudged_max = b->Mul(b->Sub(quant_max, nudged_zero_point), *scale);
}

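// Quantizes 'input' into the nudged range: clamp to [nudged_input_min,
// nudged_input_max], shift so the range starts at zero, divide by the scale,
// round to the nearest integer step (via floor(x + 0.5)), then map back to
// real values. The result is the nearest representable fake-quantized value.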
xla::ComputationDataHandle Quantize(
    xla::ComputationBuilder* b, const xla::ComputationDataHandle& input,
    const DataType data_type,
    const xla::ComputationDataHandle& nudged_input_min,
    const xla::ComputationDataHandle& nudged_input_max,
    const xla::ComputationDataHandle& input_scale) {
  xla::ComputationDataHandle one = XlaHelpers::FloatLiteral(b, data_type, 1.0f);
  xla::ComputationDataHandle inv_scale = b->Div(one, input_scale);
  xla::ComputationDataHandle half =
      XlaHelpers::FloatLiteral(b, data_type, 0.5f);

  xla::ComputationDataHandle clamped =
      b->Clamp(nudged_input_min, input, nudged_input_max);
  xla::ComputationDataHandle clamped_shifted =
      b->Sub(clamped, nudged_input_min);
  xla::ComputationDataHandle rounded =
      b->Floor(b->Add(b->Mul(clamped_shifted, inv_scale), half));
  return b->Add(b->Mul(rounded, input_scale), nudged_input_min);
}

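// Forward pass for FakeQuantWithMinMaxArgs: min/max are compile-time
// attributes, so the nudged range and scale are computed once on the host in
// the constructor and baked into the XLA computation as literals.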
class FakeQuantWithMinMaxArgsOp : public XlaOpKernel {
 public:
  explicit FakeQuantWithMinMaxArgsOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    int num_bits;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
                errors::InvalidArgument("num_bits is out of range, expected "
                                        "between 2 and 16, was: ",
                                        num_bits));
    bool narrow_range;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
    // narrow_range reserves the lowest quantized value, using
    // [1, 2^num_bits - 1] instead of [0, 2^num_bits - 1].
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;

    float input_min, input_max;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max));
    CpuNudge(input_min, input_max, quant_min_, quant_max_, &nudged_input_min_,
             &nudged_input_max_, &input_scale_);
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationDataHandle input = ctx->Input(0);
    const DataType data_type = ctx->input_type(0);

    xla::ComputationBuilder* b = ctx->builder();
    xla::ComputationDataHandle nudged_input_min =
        XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_);
    xla::ComputationDataHandle nudged_input_max =
        XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_);
    xla::ComputationDataHandle input_scale =
        XlaHelpers::FloatLiteral(b, data_type, input_scale_);
    xla::ComputationDataHandle output = Quantize(
        b, input, data_type, nudged_input_min, nudged_input_max, input_scale);
    ctx->SetOutput(0, output);
  }

 private:
  float quant_min_;
  float quant_max_;
  float nudged_input_min_;
  float nudged_input_max_;
  float input_scale_;
};

REGISTER_XLA_OP(Name("FakeQuantWithMinMaxArgs"), FakeQuantWithMinMaxArgsOp);

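// Gradient for FakeQuantWithMinMaxArgs: a straight-through estimator that
// passes the incoming gradient where the input lies inside the nudged range
// and zeroes it elsewhere.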
class FakeQuantWithMinMaxArgsGradOp : public XlaOpKernel {
 public:
  explicit FakeQuantWithMinMaxArgsGradOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    int num_bits;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
                errors::InvalidArgument("num_bits is out of range, expected "
                                        "between 2 and 16, was: ",
                                        num_bits));
    bool narrow_range;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
    const float quant_min = narrow_range ? 1 : 0;
    const float quant_max = (1 << num_bits) - 1;

    float input_min, input_max, scale;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max));
    // Only the nudged bounds are needed for the gradient; 'scale' is unused.
    CpuNudge(input_min, input_max, quant_min, quant_max, &nudged_input_min_,
             &nudged_input_max_, &scale);
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationDataHandle gradient = ctx->Input(0);
    const TensorShape gradient_shape = ctx->InputShape(0);
    xla::ComputationDataHandle input = ctx->Input(1);
    const DataType data_type = ctx->input_type(1);

    xla::ComputationBuilder* b = ctx->builder();
    xla::ComputationDataHandle nudged_input_min =
        XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_);
    xla::ComputationDataHandle nudged_input_max =
        XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_);

    xla::ComputationDataHandle between_nudged_min_max =
        b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max));
    xla::ComputationDataHandle zeroes = b->Broadcast(
        XlaHelpers::Zero(b, data_type), gradient_shape.dim_sizes());
    xla::ComputationDataHandle output =
        b->Select(between_nudged_min_max, gradient, zeroes);
    ctx->SetOutput(0, output);
  }

 private:
  float nudged_input_min_;
  float nudged_input_max_;
};

REGISTER_XLA_OP(Name("FakeQuantWithMinMaxArgsGradient"),
                FakeQuantWithMinMaxArgsGradOp);

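// Forward pass for FakeQuantWithMinMaxVars: like FakeQuantWithMinMaxArgs, but
// min/max arrive as runtime tensors, so the nudging is emitted as XLA ops via
// XlaNudge() instead of being precomputed on the host.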
class FakeQuantWithMinMaxVarsOp : public XlaOpKernel {
 public:
  explicit FakeQuantWithMinMaxVarsOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    int num_bits;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
                errors::InvalidArgument("num_bits is out of range, expected "
                                        "between 2 and 16, was: ",
                                        num_bits));
    bool narrow_range;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationDataHandle input = ctx->Input(0);
    const DataType data_type = ctx->input_type(0);
    xla::ComputationDataHandle input_min = ctx->Input(1);
    xla::ComputationDataHandle input_max = ctx->Input(2);

    xla::ComputationBuilder* b = ctx->builder();
    xla::ComputationDataHandle nudged_input_min, nudged_input_max, input_scale;
    XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_,
             &nudged_input_min, &nudged_input_max, &input_scale);

    xla::ComputationDataHandle output = Quantize(
        b, input, data_type, nudged_input_min, nudged_input_max, input_scale);
    ctx->SetOutput(0, output);
  }

 private:
  float quant_min_;
  float quant_max_;
};

REGISTER_XLA_OP(Name("FakeQuantWithMinMaxVars"), FakeQuantWithMinMaxVarsOp);

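// Gradient for FakeQuantWithMinMaxVars: produces three outputs, the gradients
// with respect to the input and to the min and max variables.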
class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel {
 public:
  explicit FakeQuantWithMinMaxVarsGradOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    int num_bits;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
                errors::InvalidArgument("num_bits is out of range, expected "
                                        "between 2 and 16, was: ",
                                        num_bits));
    bool narrow_range;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationDataHandle gradient = ctx->Input(0);
    const TensorShape gradient_shape = ctx->InputShape(0);
    xla::ComputationDataHandle input = ctx->Input(1);
    const DataType data_type = ctx->input_type(1);
    xla::ComputationDataHandle input_min = ctx->Input(2);
    xla::ComputationDataHandle input_max = ctx->Input(3);

    xla::ComputationBuilder* b = ctx->builder();
    xla::ComputationDataHandle nudged_input_min, nudged_input_max, input_scale;
    XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_,
             &nudged_input_min, &nudged_input_max, &input_scale);

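    // Gradient w.r.t. the input: pass the incoming gradient through inside
    // the nudged range, zero outside (straight-through estimator).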
    xla::ComputationDataHandle between_nudged_min_max =
        b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max));
    xla::ComputationDataHandle zero = XlaHelpers::Zero(b, data_type);
    xla::ComputationDataHandle zeroes =
        b->Broadcast(zero, gradient_shape.dim_sizes());
    xla::ComputationDataHandle output0 =
        b->Select(between_nudged_min_max, gradient, zeroes);
    ctx->SetOutput(0, output0);

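    // Gradient w.r.t. min: sum the incoming gradient over all elements that
    // were clamped at the lower bound.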
    xla::ComputationDataHandle below_min = b->Lt(input, nudged_input_min);
    xla::ComputationDataHandle output1 =
        b->ReduceAll(b->Select(below_min, gradient, zeroes), zero,
                     *ctx->GetOrCreateAdd(data_type));
    ctx->SetOutput(1, output1);

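    // Gradient w.r.t. max: likewise, summed over elements clamped at the
    // upper bound.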
    xla::ComputationDataHandle above_max = b->Gt(input, nudged_input_max);
    xla::ComputationDataHandle output2 =
        b->ReduceAll(b->Select(above_max, gradient, zeroes), zero,
                     *ctx->GetOrCreateAdd(data_type));
    ctx->SetOutput(2, output2);
  }

 private:
  float quant_min_;
  float quant_max_;
};

REGISTER_XLA_OP(Name("FakeQuantWithMinMaxVarsGradient"),
                FakeQuantWithMinMaxVarsGradOp);

}  // namespace
}  // namespace tensorflow