Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #define EIGEN_USE_THREADS
     17 
     18 #ifdef GOOGLE_CUDA
     19 #define EIGEN_USE_GPU
     20 #endif  // GOOGLE_CUDA
     21 
     22 #include "tensorflow/core/kernels/fake_quant_ops_functor.h"
     23 
     24 #include "tensorflow/core/framework/numeric_op.h"
     25 #include "tensorflow/core/framework/tensor.h"
     26 #include "tensorflow/core/lib/core/errors.h"
     27 #include "tensorflow/core/platform/protobuf.h"
     28 
     29 using tensorflow::BinaryElementWiseOp;
     30 using tensorflow::DEVICE_CPU;
     31 #if GOOGLE_CUDA
     32 using tensorflow::DEVICE_GPU;
     33 #endif
     34 using tensorflow::OpKernel;
     35 using tensorflow::OpKernelConstruction;
     36 using tensorflow::OpKernelContext;
     37 using tensorflow::Tensor;
     38 using tensorflow::TensorShape;
     39 using tensorflow::TTypes;  // NOLINT This is needed in CUDA mode, do not remove.
     40 using tensorflow::UnaryElementWiseOp;
     41 using tensorflow::errors::InvalidArgument;
     42 
     43 namespace tensorflow {
     44 
     45 typedef Eigen::ThreadPoolDevice CPUDevice;
     46 
     47 namespace {
     48 bool IsNumBitsValid(int num_bits) { return num_bits >= 2 && num_bits <= 16; }
     49 }  // namespace
     50 
     51 // -----------------------------------------------------------------------------
     52 // Implementation of FakeQuantWithMinMaxArgsOp, see its documentation in
     53 // core/ops/array_ops.cc.
     54 template <typename Device>
     55 class FakeQuantWithMinMaxArgsOp
     56     : public UnaryElementWiseOp<float, FakeQuantWithMinMaxArgsOp<Device>> {
     57  public:
     58   typedef UnaryElementWiseOp<float, FakeQuantWithMinMaxArgsOp<Device>> Base;
     59   explicit FakeQuantWithMinMaxArgsOp(OpKernelConstruction* context)
     60       : Base::UnaryElementWiseOp(context) {
     61     OP_REQUIRES_OK(context, context->GetAttr("min", &min_));
     62     OP_REQUIRES_OK(context, context->GetAttr("max", &max_));
     63     OP_REQUIRES(context, min_ < max_,
     64                 InvalidArgument("min has to be smaller than max, was: ", min_,
     65                                 " >= ", max_));
     66     int num_bits;
     67     OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
     68     OP_REQUIRES(
     69         context, IsNumBitsValid(num_bits),
     70         InvalidArgument("num_bits must be between 2 and 16, inclusive"));
     71     bool narrow_range;
     72     OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
     73     quant_min_ = narrow_range ? 1 : 0;
     74     quant_max_ = (1 << num_bits) - 1;
     75   }
     76 
     77   void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
     78     FakeQuantWithMinMaxArgsFunctor<Device> functor;
     79     functor(context->eigen_device<Device>(), input.flat<float>(), min_, max_,
     80             quant_min_, quant_max_, output->flat<float>());
     81   }
     82 
     83  private:
     84   float min_;
     85   float max_;
     86   int quant_min_;
     87   int quant_max_;
     88 };
     89 
     90 // Implementation of FakeQuantWithMinMaxArgsGradientOp, see its documentation in
     91 // core/ops/array_ops.cc.
     92 template <typename Device>
     93 class FakeQuantWithMinMaxArgsGradientOp
     94     : public BinaryElementWiseOp<float,
     95                                  FakeQuantWithMinMaxArgsGradientOp<Device>> {
     96  public:
     97   typedef BinaryElementWiseOp<float, FakeQuantWithMinMaxArgsGradientOp<Device>>
     98       Base;
     99   explicit FakeQuantWithMinMaxArgsGradientOp(OpKernelConstruction* context)
    100       : Base::BinaryElementWiseOp(context) {
    101     OP_REQUIRES_OK(context, context->GetAttr("min", &min_));
    102     OP_REQUIRES_OK(context, context->GetAttr("max", &max_));
    103     OP_REQUIRES(context, min_ < max_,
    104                 InvalidArgument("min has to be smaller than max, was: ", min_,
    105                                 " >= ", max_));
    106     int num_bits;
    107     OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
    108     OP_REQUIRES(
    109         context, IsNumBitsValid(num_bits),
    110         InvalidArgument("num_bits must be between 2 and 16, inclusive"));
    111     bool narrow_range;
    112     OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
    113     quant_min_ = narrow_range ? 1 : 0;
    114     quant_max_ = (1 << num_bits) - 1;
    115   }
    116 
    117   template <int NDIMS>
    118   void Operate(OpKernelContext* context, const Tensor& gradient,
    119                const Tensor& input, Tensor* output) {
    120     OperateNoTemplate(context, gradient, input, output);
    121   }
    122 
    123   void OperateNoTemplate(OpKernelContext* context, const Tensor& gradient,
    124                          const Tensor& input, Tensor* output) {
    125     OP_REQUIRES(context, input.IsSameSize(gradient),
    126                 InvalidArgument("gradient and input must be the same size"));
    127     FakeQuantWithMinMaxArgsGradientFunctor<Device> functor;
    128     functor(context->eigen_device<Device>(), gradient.flat<float>(),
    129             input.flat<float>(), min_, max_, quant_min_, quant_max_,
    130             output->flat<float>());
    131   }
    132 
    133  private:
    134   float min_;
    135   float max_;
    136   int quant_min_;
    137   int quant_max_;
    138 };
    139 
// CPU kernel registrations for the attribute-based ("Args") fake-quant op and
// its gradient.
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxArgs").Device(DEVICE_CPU),
                        FakeQuantWithMinMaxArgsOp<CPUDevice>);
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxArgsGradient").Device(DEVICE_CPU),
    FakeQuantWithMinMaxArgsGradientOp<CPUDevice>);
    145 
    146 #if GOOGLE_CUDA
    147 typedef Eigen::GpuDevice GPUDevice;
    148 
    149 // Forward declarations for functor specializations for GPU.
    150 template <>
    151 void FakeQuantWithMinMaxArgsFunctor<GPUDevice>::operator()(
    152     const GPUDevice& d, typename TTypes<float>::ConstFlat inputs,
    153     const float min, const float max, const int quant_min, const int quant_max,
    154     typename TTypes<float>::Flat outputs);
    155 extern template struct FakeQuantWithMinMaxArgsFunctor<GPUDevice>;
    156 REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxArgs").Device(DEVICE_GPU),
    157                         FakeQuantWithMinMaxArgsOp<GPUDevice>);
    158 
    159 template <>
    160 void FakeQuantWithMinMaxArgsGradientFunctor<GPUDevice>::operator()(
    161     const GPUDevice& d, typename TTypes<float>::ConstFlat gradients,
    162     typename TTypes<float>::ConstFlat inputs, const float min, const float max,
    163     const int quant_min, const int quant_max,
    164     typename TTypes<float>::Flat backprops);
    165 REGISTER_KERNEL_BUILDER(
    166     Name("FakeQuantWithMinMaxArgsGradient").Device(DEVICE_GPU),
    167     FakeQuantWithMinMaxArgsGradientOp<GPUDevice>);
    168 #endif  // GOOGLE_CUDA
    169 
    170 // -----------------------------------------------------------------------------
    171 // Implementation of FakeQuantWithMinMaxVarsOp, see its documentation in
    172 // core/ops/array_ops.cc.
    173 template <typename Device>
    174 class FakeQuantWithMinMaxVarsOp : public OpKernel {
    175  public:
    176   explicit FakeQuantWithMinMaxVarsOp(OpKernelConstruction* context)
    177       : OpKernel::OpKernel(context) {
    178     int num_bits;
    179     OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
    180     OP_REQUIRES(
    181         context, IsNumBitsValid(num_bits),
    182         InvalidArgument("num_bits must be between 2 and 16, inclusive"));
    183     bool narrow_range;
    184     OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
    185     quant_min_ = narrow_range ? 1 : 0;
    186     quant_max_ = (1 << num_bits) - 1;
    187   }
    188 
    189   void Compute(OpKernelContext* context) override {
    190     CHECK_EQ(3, context->num_inputs());
    191     const Tensor& input = context->input(0);
    192     const Tensor& min = context->input(1);
    193     const Tensor& max = context->input(2);
    194 
    195     Tensor* output;
    196     OP_REQUIRES_OK(context,
    197                    context->allocate_output(0, input.shape(), &output));
    198 
    199     FakeQuantWithMinMaxVarsFunctor<Device> functor;
    200     functor(context->eigen_device<Device>(), input.flat<float>(),
    201             min.scalar<float>(), max.scalar<float>(), quant_min_, quant_max_,
    202             output->flat<float>());
    203   }
    204 
    205  private:
    206   int quant_min_;
    207   int quant_max_;
    208 };
    209 
    210 // Implementation of FakeQuantWithMinMaxVarsGradientOp, see its documentation in
    211 // core/ops/array_ops.cc.
    212 template <typename Device>
    213 class FakeQuantWithMinMaxVarsGradientOp : public OpKernel {
    214  public:
    215   explicit FakeQuantWithMinMaxVarsGradientOp(OpKernelConstruction* context)
    216       : OpKernel::OpKernel(context) {
    217     int num_bits;
    218     OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
    219     OP_REQUIRES(
    220         context, IsNumBitsValid(num_bits),
    221         InvalidArgument("num_bits must be between 2 and 16, inclusive"));
    222     bool narrow_range;
    223     OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
    224     quant_min_ = narrow_range ? 1 : 0;
    225     quant_max_ = (1 << num_bits) - 1;
    226   }
    227 
    228   void Compute(OpKernelContext* context) override {
    229     CHECK_EQ(4, context->num_inputs());
    230     const Tensor& gradient = context->input(0);
    231     const Tensor& input = context->input(1);
    232     OP_REQUIRES(context, input.IsSameSize(gradient),
    233                 InvalidArgument("gradient and input must be the same size"));
    234     const Tensor& min = context->input(2);
    235     const Tensor& max = context->input(3);
    236 
    237     Tensor* grad_wrt_input;
    238     OP_REQUIRES_OK(context,
    239                    context->allocate_output(0, input.shape(), &grad_wrt_input));
    240 
    241     TensorShape scalar_shape;
    242     Tensor* grad_wrt_min;
    243     OP_REQUIRES_OK(context,
    244                    context->allocate_output(1, scalar_shape, &grad_wrt_min));
    245 
    246     Tensor* grad_wrt_max;
    247     OP_REQUIRES_OK(context,
    248                    context->allocate_output(2, scalar_shape, &grad_wrt_max));
    249 
    250     FakeQuantWithMinMaxVarsGradientFunctor<Device> functor;
    251     functor(context->eigen_device<Device>(), gradient.flat<float>(),
    252             input.flat<float>(), min.scalar<float>(), max.scalar<float>(),
    253             quant_min_, quant_max_, grad_wrt_input->flat<float>(),
    254             grad_wrt_min->scalar<float>(), grad_wrt_max->scalar<float>());
    255   }
    256 
    257  private:
    258   int quant_min_;
    259   int quant_max_;
    260 };
    261 
// CPU kernel registrations for the variable-based ("Vars") fake-quant op and
// its gradient.
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVars").Device(DEVICE_CPU),
                        FakeQuantWithMinMaxVarsOp<CPUDevice>);
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxVarsGradient").Device(DEVICE_CPU),
    FakeQuantWithMinMaxVarsGradientOp<CPUDevice>);
    267 
#if GOOGLE_CUDA
// Forward declarations of the GPU specializations of the vars-based functors.
// The `extern template` lines declare that the explicit instantiations are
// provided by another translation unit (presumably the CUDA-compiled .cu.cc —
// TODO(review): confirm).
template <>
void FakeQuantWithMinMaxVarsFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstFlat inputs,
    typename TTypes<float>::ConstScalar min,
    typename TTypes<float>::ConstScalar max, const int quant_min,
    const int quant_max, typename TTypes<float>::Flat output);
extern template struct FakeQuantWithMinMaxVarsFunctor<GPUDevice>;
// `min` and `max` are pinned to host memory for the GPU registration.
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVars")
                            .Device(DEVICE_GPU)
                            .HostMemory("min")
                            .HostMemory("max"),
                        FakeQuantWithMinMaxVarsOp<GPUDevice>);

template <>
void FakeQuantWithMinMaxVarsGradientFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstFlat gradients,
    typename TTypes<float>::ConstFlat inputs,
    typename TTypes<float>::ConstScalar min,
    typename TTypes<float>::ConstScalar max, const int quant_min,
    const int quant_max, typename TTypes<float>::Flat backprops_wrt_input,
    typename TTypes<float>::Scalar backprop_wrt_min,
    typename TTypes<float>::Scalar backprop_wrt_max);
extern template struct FakeQuantWithMinMaxVarsGradientFunctor<GPUDevice>;
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsGradient")
                            .Device(DEVICE_GPU)
                            .HostMemory("min")
                            .HostMemory("max"),
                        FakeQuantWithMinMaxVarsGradientOp<GPUDevice>);
#endif  // GOOGLE_CUDA
    298 
    299 // -----------------------------------------------------------------------------
    300 // Implementation of FakeQuantWithMinMaxVarsPerChannelOp, see its documentation
    301 // in core/ops/array_ops.cc.
    302 template <typename Device>
    303 class FakeQuantWithMinMaxVarsPerChannelOp : public OpKernel {
    304  public:
    305   explicit FakeQuantWithMinMaxVarsPerChannelOp(OpKernelConstruction* context)
    306       : OpKernel::OpKernel(context) {
    307     int num_bits;
    308     OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
    309     OP_REQUIRES(
    310         context, IsNumBitsValid(num_bits),
    311         InvalidArgument("num_bits must be between 2 and 16, inclusive"));
    312     bool narrow_range;
    313     OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
    314     quant_min_ = narrow_range ? 1 : 0;
    315     quant_max_ = (1 << num_bits) - 1;
    316   }
    317 
    318   void Compute(OpKernelContext* context) override {
    319     CHECK_EQ(3, context->num_inputs());
    320     const Tensor& input = context->input(0);
    321     const int depth = input.dim_size(input.dims() - 1);  // last dimension size.
    322     const Tensor& min = context->input(1);
    323     OP_REQUIRES(context, min.dim_size(0) == depth,
    324                 InvalidArgument("min has incorrect size, expected ", depth,
    325                                 " was ", min.dim_size(0)));
    326     const Tensor& max = context->input(2);
    327     OP_REQUIRES(context, max.dim_size(0) == depth,
    328                 InvalidArgument("max has incorrect size, expected ", depth,
    329                                 " was ", max.dim_size(0)));
    330 
    331     Tensor* output;
    332     OP_REQUIRES_OK(context,
    333                    context->allocate_output(0, input.shape(), &output));
    334 
    335     FakeQuantWithMinMaxVarsPerChannelFunctor<Device> functor;
    336     functor(context->eigen_device<Device>(), input.flat_inner_dims<float, 2>(),
    337             min.vec<float>(), max.vec<float>(), quant_min_, quant_max_,
    338             output->flat_inner_dims<float, 2>());
    339   }
    340 
    341  private:
    342   int quant_min_;
    343   int quant_max_;
    344 };
    345 
    346 // Implementation of FakeQuantWithMinMaxVarsPerChannelGradientOp, see its
    347 // documentation in core/ops/array_ops.cc.
    348 template <typename Device>
    349 class FakeQuantWithMinMaxVarsPerChannelGradientOp : public OpKernel {
    350  public:
    351   explicit FakeQuantWithMinMaxVarsPerChannelGradientOp(
    352       OpKernelConstruction* context)
    353       : OpKernel::OpKernel(context) {
    354     int num_bits;
    355     OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
    356     OP_REQUIRES(
    357         context, IsNumBitsValid(num_bits),
    358         InvalidArgument("num_bits must be between 2 and 16, inclusive"));
    359     bool narrow_range;
    360     OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
    361     quant_min_ = narrow_range ? 1 : 0;
    362     quant_max_ = (1 << num_bits) - 1;
    363   }
    364 
    365   void Compute(OpKernelContext* context) override {
    366     CHECK_EQ(4, context->num_inputs());
    367     const Tensor& gradient = context->input(0);
    368     const Tensor& input = context->input(1);
    369     OP_REQUIRES(context, input.IsSameSize(gradient),
    370                 InvalidArgument("gradient and input must be the same size"));
    371     const int depth = input.dim_size(input.dims() - 1);  // last dimension size.
    372     const Tensor& min = context->input(2);
    373     OP_REQUIRES(context, min.dim_size(0) == depth,
    374                 InvalidArgument("min has incorrect size, expected ", depth,
    375                                 " was ", min.dim_size(0)));
    376     const Tensor& max = context->input(3);
    377     OP_REQUIRES(context, max.dim_size(0) == depth,
    378                 InvalidArgument("max has incorrect size, expected ", depth,
    379                                 " was ", max.dim_size(0)));
    380 
    381     Tensor* grad_wrt_input;
    382     OP_REQUIRES_OK(context,
    383                    context->allocate_output(0, input.shape(), &grad_wrt_input));
    384 
    385     TensorShape min_max_shape({input.dim_size(input.dims() - 1)});
    386     Tensor* grad_wrt_min;
    387     OP_REQUIRES_OK(context,
    388                    context->allocate_output(1, min_max_shape, &grad_wrt_min));
    389 
    390     Tensor* grad_wrt_max;
    391     OP_REQUIRES_OK(context,
    392                    context->allocate_output(2, min_max_shape, &grad_wrt_max));
    393 
    394     FakeQuantWithMinMaxVarsPerChannelGradientFunctor<Device> functor;
    395     functor(
    396         context->eigen_device<Device>(), gradient.flat_inner_dims<float, 2>(),
    397         input.flat_inner_dims<float, 2>(), min.vec<float>(), max.vec<float>(),
    398         quant_min_, quant_max_, grad_wrt_input->flat_inner_dims<float, 2>(),
    399         grad_wrt_min->vec<float>(), grad_wrt_max->vec<float>());
    400   }
    401 
    402  private:
    403   int quant_min_;
    404   int quant_max_;
    405 };
    406 
// CPU kernel registrations for the per-channel fake-quant op and its gradient.
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxVarsPerChannel").Device(DEVICE_CPU),
    FakeQuantWithMinMaxVarsPerChannelOp<CPUDevice>);
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxVarsPerChannelGradient").Device(DEVICE_CPU),
    FakeQuantWithMinMaxVarsPerChannelGradientOp<CPUDevice>);
    413 
#if GOOGLE_CUDA
// Forward declarations of the GPU specializations of the per-channel functors.
// The `extern template` lines declare that the explicit instantiations are
// provided by another translation unit (presumably the CUDA-compiled .cu.cc —
// TODO(review): confirm).
// NOTE(review): this declaration uses TTypes<float>::ConstFlat for min/max
// while the gradient functor below uses ConstVec; both are 1-D views, but
// confirm against fake_quant_ops_functor.h that the signatures match.
template <>
void FakeQuantWithMinMaxVarsPerChannelFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstMatrix inputs,
    typename TTypes<float>::ConstFlat min,
    typename TTypes<float>::ConstFlat max, const int quant_min,
    const int quant_max, typename TTypes<float>::Matrix outputs);
extern template struct FakeQuantWithMinMaxVarsPerChannelFunctor<GPUDevice>;

// `min` and `max` are pinned to host memory for the GPU registration.
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsPerChannel")
                            .Device(DEVICE_GPU)
                            .HostMemory("min")
                            .HostMemory("max"),
                        FakeQuantWithMinMaxVarsPerChannelOp<GPUDevice>);

template <>
void FakeQuantWithMinMaxVarsPerChannelGradientFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstMatrix gradients,
    typename TTypes<float>::ConstMatrix inputs,
    typename TTypes<float>::ConstVec min, typename TTypes<float>::ConstVec max,
    const int quant_min, const int quant_max,
    typename TTypes<float>::Matrix backprops_wrt_input,
    typename TTypes<float>::Vec backprop_wrt_min,
    typename TTypes<float>::Vec backprop_wrt_max);
extern template struct FakeQuantWithMinMaxVarsPerChannelGradientFunctor<
    GPUDevice>;

REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsPerChannelGradient")
                            .Device(DEVICE_GPU)
                            .HostMemory("min")
                            .HostMemory("max"),
                        FakeQuantWithMinMaxVarsPerChannelGradientOp<GPUDevice>);
#endif  // GOOGLE_CUDA
    447 
    448 }  // namespace tensorflow
    449