/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/image_ops.cc
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/adjust_contrast_op.h"
#include <algorithm>
#include <cstring>
#include <memory>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif

// AdjustContrastOp is deprecated as of GraphDef version >= 2.
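// The kernel stays registered so that graphs serialized before that version
// continue to load and run; new graphs should use AdjustContrastv2 below.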

template <typename Device, typename T>
class AdjustContrastOp : public OpKernel {
 public:
  explicit AdjustContrastOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& factor = context->input(1);
    const Tensor& min_value = context->input(2);
    const Tensor& max_value = context->input(3);
    OP_REQUIRES(context, input.dims() >= 3,
                errors::InvalidArgument("input must be at least 3-D, got shape ",
                                        input.shape().DebugString()));
    const int64 height = input.dim_size(input.dims() - 3);
    const int64 width = input.dim_size(input.dims() - 2);
    const int64 channels = input.dim_size(input.dims() - 1);

    OP_REQUIRES(context, TensorShapeUtils::IsScalar(factor.shape()),
                errors::InvalidArgument("contrast_factor must be scalar: ",
                                        factor.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(min_value.shape()),
                errors::InvalidArgument("min_value must be scalar: ",
                                        min_value.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(max_value.shape()),
                errors::InvalidArgument("max_value must be scalar: ",
                                        max_value.shape().DebugString()));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));

    Tensor mean_values;
    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<float>::value,
                                                   TensorShape(input.shape()),
                                                   &mean_values));

    if (input.NumElements() > 0) {
      const int64 batch = input.NumElements() / (height * width * channels);
      const int64 shape[4] = {batch, height, width, channels};
      functor::AdjustContrast<Device, T>()(
          context->eigen_device<Device>(), input.shaped<T, 4>(shape),
          factor.scalar<float>(), min_value.scalar<float>(),
          max_value.scalar<float>(), mean_values.shaped<float, 4>(shape),
          output->shaped<float, 4>(shape));
    }
  }
};

#define REGISTER_KERNEL(T)                                              \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("AdjustContrast").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      AdjustContrastOp<CPUDevice, T>);

REGISTER_KERNEL(uint8);
REGISTER_KERNEL(int8);
REGISTER_KERNEL(int16);
REGISTER_KERNEL(int32);
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA
// Forward declarations of the function specializations for GPU (to prevent
// building the GPU versions here; they will be built when _gpu.cu.cc is
// compiled).
namespace functor {
#define DECLARE_GPU_SPEC(T)                                         \
  template <>                                                       \
  void AdjustContrast<GPUDevice, T>::operator()(                    \
      const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input, \
      typename TTypes<float>::ConstScalar contrast_factor,          \
      typename TTypes<float>::ConstScalar min_value,                \
      typename TTypes<float>::ConstScalar max_value,                \
      typename TTypes<float, 4>::Tensor mean_values,                \
      typename TTypes<float, 4>::Tensor output);                    \
  extern template struct AdjustContrast<GPUDevice, T>;

DECLARE_GPU_SPEC(uint8);
DECLARE_GPU_SPEC(int8);
DECLARE_GPU_SPEC(int16);
DECLARE_GPU_SPEC(int32);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(T)                                          \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("AdjustContrast").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      AdjustContrastOp<GPUDevice, T>);
REGISTER_GPU_KERNEL(uint8);
REGISTER_GPU_KERNEL(int8);
REGISTER_GPU_KERNEL(int16);
REGISTER_GPU_KERNEL(int32);
REGISTER_GPU_KERNEL(float);
REGISTER_GPU_KERNEL(double);
#undef REGISTER_GPU_KERNEL

#endif  // GOOGLE_CUDA

class AdjustContrastOpV2Base : public OpKernel {
 protected:
  explicit AdjustContrastOpV2Base(OpKernelConstruction* context)
      : OpKernel(context) {}

  struct ComputeOptions {
    const Tensor* input = nullptr;
    const Tensor* factor = nullptr;
    Tensor* output = nullptr;
    int64 batch = 0;
    int64 height = 0;
    int64 width = 0;
    int64 channels = 0;
  };

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& factor = context->input(1);
    OP_REQUIRES(context, input.dims() >= 3,
                errors::InvalidArgument("input must be at least 3-D, got shape ",
                                        input.shape().DebugString()));
    const int64 height = input.dim_size(input.dims() - 3);
    const int64 width = input.dim_size(input.dims() - 2);
    const int64 channels = input.dim_size(input.dims() - 1);

    OP_REQUIRES(context, TensorShapeUtils::IsScalar(factor.shape()),
                errors::InvalidArgument("contrast_factor must be scalar: ",
                                        factor.shape().DebugString()));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));

    if (input.NumElements() > 0) {
      const int64 batch = input.NumElements() / (height * width * channels);
      ComputeOptions options;
      options.input = &input;
      options.factor = &factor;
      options.output = output;
      options.batch = batch;
      options.height = height;
      options.width = width;
      options.channels = channels;
      DoCompute(context, options);
    }
  }

  virtual void DoCompute(OpKernelContext* context,
                         const ComputeOptions& options) = 0;
};

template <typename Device>
class AdjustContrastOpv2;

template <>
class AdjustContrastOpv2<CPUDevice> : public AdjustContrastOpV2Base {
 public:
  explicit AdjustContrastOpv2(OpKernelConstruction* context)
      : AdjustContrastOpV2Base(context) {}

  void DoCompute(OpKernelContext* context,
                 const ComputeOptions& options) override {
    const int64 batch = options.batch;
    const int64 height = options.height;
    const int64 width = options.width;
    const int64 channels = options.channels;
    const int64 image_size = height * width;
    const Tensor* input = options.input;
    const Tensor* factor = options.factor;
    Tensor* output = options.output;
    Tensor mean_values;
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DataTypeToEnum<float>::value,
                                TensorShape({batch, channels}), &mean_values));
    // TODO(zhengxq): for multiple batches, shard the computation across
    // batches.
    auto input_data = input->shaped<float, 3>({batch, image_size, channels});
    auto mean_data = mean_values.tensor<float, 2>();
    auto output_data = output->shaped<float, 3>({batch, image_size, channels});

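    // Taken together, the three passes below compute, for every batch i,
    // pixel p, and channel k:
    //   output(i, p, k) = mean(i, k) + factor * (input(i, p, k) - mean(i, k))
    // where mean(i, k) is the mean of input(i, :, k) over the image.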
    // Calculate the mean of the inputs.
    ReduceMeanAcrossImage(input_data, mean_data, output_data);
    // Broadcast the mean into the outputs.
    BroadcastAcrossImage(mean_data, output_data);
    // Increment the outputs with the scaled difference through their flat
    // structure.
    IncrementWithScaling(input_data, factor->scalar<float>(), output_data);
  }

 private:
  // Compute the mean of the inputs along the image dimension, i.e. dim_1, in a
  // 3D tensor. Effectively means(i, k) = inputs(i, :, k).mean().
  void ReduceMeanAcrossImage(typename TTypes<float, 3>::ConstTensor input,
                             typename TTypes<float, 2>::Tensor mean,
                             typename TTypes<float, 3>::Tensor scratch) {
    const int64 batch = input.dimension(0);
    const int64 image_size = input.dimension(1);
    const int64 channels = input.dimension(2);
    TTypes<float, 1>::ConstTensor input_flat(&input(0, 0, 0), input.size());
    TTypes<float, 1>::Tensor mean_flat(&mean(0, 0), mean.size());
    TTypes<float, 1>::Tensor summation_scratch(&scratch(0, 0, 0),
                                               scratch.size());
    typedef Eigen::array<Eigen::DenseIndex, 1> Index;
    const int64 plane_size = image_size * channels;
    // Since the number of channels in the early layers is often small, a
    // straightforward summing loop cannot make good use of vectorization.
    // This algorithm instead repeatedly folds each image plane in half, until
    // only one set of channels remains.
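    // The folding performs roughly the same number of additions as a plain
    // loop (about image_size * channels per plane), but each round adds two
    // long contiguous runs of floats, which vectorizes well regardless of the
    // channel count.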
    for (int64 i = 0; i < batch; i++) {
      auto input_plane =
          input_flat.slice(Index(i * plane_size), Index(plane_size));
      auto summation_plane =
          summation_scratch.slice(Index(i * plane_size), Index(plane_size));
      int64 remaining_size = image_size;
      int round = 0;
      // Sum input(i, :, k) into mean(i, k). Repeatedly split the array in
      // half and sum the two halves, until only one set of channels is left,
      // which holds the sum. Since each half is large enough, this leads to
      // much better vectorization between components. An example of how this
      // works:
      //
      //   x = float[4096, 3]
      //   round 0
      //     y[:2048, :] = x[:2048, :] + x[2048:, :]
      //   round 1
      //     y[:1024, :] += y[1024:2048, :]
      //   round 2
      //     y[:512, :] += y[512:1024, :]
      //   ...
      //   round 11
      //     y[:1, :] += y[1:2, :]
      //   At this point y[0, :] holds the sum of all x[:, :]
      //
      // The algorithm itself can handle sizes that are not powers of two. Note
      // that in each round we sum up elements that are contiguous, so we can
      // use their flattened structure to gain vectorization efficiency.
      do {
        int64 right_size = remaining_size / 2;
        int64 left_size = remaining_size - right_size;
        DCHECK(left_size == right_size || left_size == right_size + 1);
        if (round == 0) {
          // In the first round, sum the left side and right side of the input
          // array into the summation area.
          summation_plane.slice(Index(0), Index(right_size * channels)) =
              input_plane.slice(Index(left_size * channels),
                                Index(right_size * channels)) +
              input_plane.slice(Index(0), Index(right_size * channels));
          if (left_size > right_size) {
            DCHECK_EQ(left_size - right_size, 1);
            // Copy over the remaining column if remaining_size is odd.
            // This also handles the case where image_size == 1.
            summation_plane.slice(Index(right_size * channels),
                                  Index(channels)) =
                input_plane.slice(Index(right_size * channels),
                                  Index(channels));
          }
        } else {
          // For all the remaining rounds, add the second half of the inputs
          // into the first half of the inputs. With the flat structure and
          // large size, this utilizes vectorization between components.
          summation_plane.slice(Index(0), Index(right_size * channels)) +=
              summation_plane.slice(Index(left_size * channels),
                                    Index(right_size * channels));
        }
        remaining_size = left_size;
        round++;
      } while (remaining_size > 1);
      const float mean_scaling = 1.0f / image_size;
      // The first `channels` elements of summation_plane now hold the sum.
      // Scale by 1 / image_size and copy over to the means.
      auto mean_plane = mean_flat.slice(Index(i * channels), Index(channels));
      mean_plane =
          summation_plane.slice(Index(0), Index(channels)) * mean_scaling;
    }
  }

  // Broadcast 2D inputs into 3D outputs across the image dimension, i.e.,
  // dim-1.
  void BroadcastAcrossImage(typename TTypes<float, 2>::Tensor inputs,
                            typename TTypes<float, 3>::Tensor outputs) {
    int64 batch = outputs.dimension(0);
    int64 image_size = outputs.dimension(1);
    int64 channels = outputs.dimension(2);
    // Similar to the reduction case, a straightforward implementation of this
    // does not vectorize well because of the small channel size. This
    // algorithm repeatedly increases the area to be copied, which leads to
    // much better vectorization in the copy.
    for (int64 i = 0; i < batch; i++) {
      // Copy the inputs into the outputs for this batch. Effectively:
      // outputs(i, :, k) = inputs(i, k). An example of how this algorithm
      // works:
      //
      //    x = float[1, 3], y = float[2048, 3]
      //    round 0
      //      y[:1, :] = x[:, :]
      //    round 1
      //      y[1:2, :] = y[:1, :]
      //    round 2
      //      y[2:4, :] = y[:2, :]
      //    round 3
      //      y[4:8, :] = y[:4, :]
      //    ...
      //    round 11
      //      y[1024:2048, :] = y[:1024, :]
      //    At this point y[:, k] == x[k]
      //
      // The algorithm works for sizes that are not powers of two. In each
      // round, the elements that are copied are contiguous, so the copy
      // benefits from the vectorized memcpy.
      const float* mean_p = &inputs(i, 0);
      // Copy the first set of channels.
      float* output_p = &outputs(i, 0, 0);
      memcpy(output_p, mean_p, sizeof(float) * channels);
      int64 copied = 1;
      while (copied < image_size) {
        // Repeatedly increase the number of elements to copy so the copies
        // vectorize better. However, the source of the copy must be small
        // enough to stay in the cache.
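        // For example, with 3 float channels a source of at most 1024 pixels
        // is 1024 * 3 * 4 = 12288 bytes, which fits comfortably in a typical
        // L1 cache.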
        const int64 kMaxToCopy = 1024;
        int64 to_copy = std::min({copied, image_size - copied, kMaxToCopy});
        memcpy(output_p + channels * copied, output_p,
               to_copy * channels * sizeof(float));
        copied += to_copy;
      }
    }
  }

  // Increment the outputs with the scaled difference between inputs and
  // outputs. Effectively: outputs += factor * (inputs - outputs).
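  // Since the outputs hold the broadcast per-channel means at this point, this
  // is equivalent to outputs = (1 - factor) * mean + factor * inputs, i.e. a
  // linear interpolation between each pixel and its channel mean (factor > 1
  // extrapolates, increasing contrast).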
  void IncrementWithScaling(typename TTypes<float, 3>::ConstTensor input,
                            typename TTypes<float>::ConstScalar factor,
                            typename TTypes<float, 3>::Tensor output) {
    const float factor_value = factor();
    float* p = output.data();
    const float* q = input.data();
    for (int64 n = 0; n < input.size(); ++n) {
      p[n] += factor_value * (q[n] - p[n]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("AdjustContrastv2").Device(DEVICE_CPU),
                        AdjustContrastOpv2<CPUDevice>);
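
// For reference only: a minimal, unoptimized sketch of the computation the
// CPU kernel above performs for one NHWC float image. The names `in`, `out`,
// and the plain-pointer interface are illustrative and are not part of the
// kernel or its API.
//
//   void AdjustContrastReference(const float* in, float* out, int64 height,
//                                int64 width, int64 channels, float factor) {
//     const int64 image_size = height * width;
//     for (int64 k = 0; k < channels; ++k) {
//       float mean = 0.f;
//       for (int64 p = 0; p < image_size; ++p) mean += in[p * channels + k];
//       mean /= image_size;
//       for (int64 p = 0; p < image_size; ++p) {
//         const int64 idx = p * channels + k;
//         out[idx] = mean + factor * (in[idx] - mean);
//       }
//     }
//   }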

#if GOOGLE_CUDA
// Forward declarations of the function specializations for GPU (to prevent
// building the GPU versions here; they will be built when _gpu.cu.cc is
// compiled).
namespace functor {
template <>
void AdjustContrastv2<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float, 4>::ConstTensor input,
    typename TTypes<float>::ConstScalar contrast_factor,
    typename TTypes<float, 4>::Tensor output);
extern template struct AdjustContrastv2<GPUDevice>;
}  // namespace functor

template <>
class AdjustContrastOpv2<GPUDevice> : public AdjustContrastOpV2Base {
 public:
  explicit AdjustContrastOpv2(OpKernelConstruction* context)
      : AdjustContrastOpV2Base(context) {}

  void DoCompute(OpKernelContext* context,
                 const ComputeOptions& options) override {
    const int64 shape[4] = {options.batch, options.height, options.width,
                            options.channels};
    functor::AdjustContrastv2<GPUDevice>()(
        context->eigen_device<GPUDevice>(),
        options.input->shaped<float, 4>(shape), options.factor->scalar<float>(),
        options.output->shaped<float, 4>(shape));
  }
};

REGISTER_KERNEL_BUILDER(Name("AdjustContrastv2").Device(DEVICE_GPU),
                        AdjustContrastOpv2<GPUDevice>);
#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
template <>
class AdjustContrastOpv2<SYCLDevice> : public AdjustContrastOpV2Base {
 public:
  explicit AdjustContrastOpv2(OpKernelConstruction* context)
      : AdjustContrastOpV2Base(context) {}

  void DoCompute(OpKernelContext* context,
                 const ComputeOptions& options) override {
    const int64 shape[4] = {options.batch, options.height, options.width,
                            options.channels};
    functor::AdjustContrastv2<SYCLDevice>()(
        context->eigen_device<SYCLDevice>(),
        options.input->shaped<float, 4>(shape), options.factor->scalar<float>(),
        options.output->shaped<float, 4>(shape));
  }
};
REGISTER_KERNEL_BUILDER(Name("AdjustContrastv2").Device(DEVICE_SYCL),
                        AdjustContrastOpv2<SYCLDevice>);
#endif  // TENSORFLOW_USE_SYCL

}  // namespace tensorflow