Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 Licensed under the Apache License, Version 2.0 (the "License");
      3 you may not use this file except in compliance with the License.
      4 You may obtain a copy of the License at
      5 
      6     http://www.apache.org/licenses/LICENSE-2.0
      7 
      8 Unless required by applicable law or agreed to in writing, software
      9 distributed under the License is distributed on an "AS IS" BASIS,
     10 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     11 See the License for the specific language governing permissions and
     12 limitations under the License.
     13 ==============================================================================*/
     14 #define EIGEN_USE_THREADS
     15 
     16 #if GOOGLE_CUDA
     17 #define EIGEN_USE_GPU
     18 #endif
     19 
#include <cmath>
#include <memory>

#include "tensorflow/core/kernels/adjust_hue_op.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/work_sharder.h"
     34 
     35 namespace tensorflow {
     36 
     37 typedef Eigen::ThreadPoolDevice CPUDevice;
     38 typedef Eigen::GpuDevice GPUDevice;
     39 
     40 class AdjustHueOpBase : public OpKernel {
     41  protected:
     42   explicit AdjustHueOpBase(OpKernelConstruction* context) : OpKernel(context) {}
     43 
     44   struct ComputeOptions {
     45     const Tensor* input;
     46     const Tensor* delta;
     47     Tensor* output;
     48     int64 channel_count;
     49   };
     50 
     51   virtual void DoCompute(OpKernelContext* context,
     52                          const ComputeOptions& options) = 0;
     53 
     54   void Compute(OpKernelContext* context) override {
     55     const Tensor& input = context->input(0);
     56     const Tensor& delta = context->input(1);
     57     OP_REQUIRES(context, input.dims() >= 3,
     58                 errors::InvalidArgument("input must be at least 3-D, got shape",
     59                                         input.shape().DebugString()));
     60     OP_REQUIRES(context, TensorShapeUtils::IsScalar(delta.shape()),
     61                 errors::InvalidArgument("delta must be scalar: ",
     62                                         delta.shape().DebugString()));
     63     auto channels = input.dim_size(input.dims() - 1);
     64     OP_REQUIRES(
     65         context, channels == 3,
     66         errors::InvalidArgument("input must have 3 channels but instead has ",
     67                                 channels, " channels."));
     68 
     69     Tensor* output = nullptr;
     70     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
     71                                 {0}, 0, input.shape(), &output));
     72 
     73     if (input.NumElements() > 0) {
     74       const int64 channel_count = input.NumElements() / channels;
     75       ComputeOptions options;
     76       options.input = &input;
     77       options.delta = &delta;
     78       options.output = output;
     79       options.channel_count = channel_count;
     80       DoCompute(context, options);
     81     }
     82   }
     83 };
     84 
     85 template <class Device>
     86 class AdjustHueOp;
     87 
     88 namespace internal {
     89 
     90 // Helper function to convert a RGB color to H-and-V-range. H is in the range
     91 // of [0, 6] instead of the normal [0, 1]
     92 static void rgb_to_hv_range(float r, float g, float b, float* h, float* v_min,
     93                             float* v_max) {
     94   float v_mid;
     95   int h_category;
     96   // According to the figures in:
     97   // https://en.wikipedia.org/wiki/HSL_and_HSV#Hue_and_chroma
     98   // For the conditions, we don't care about the case where two components are
     99   // equal. It is okay to count it in either side in that case.
    100   if (r < g) {
    101     if (b < r) {
    102       // b < r < g
    103       *v_max = g;
    104       v_mid = r;
    105       *v_min = b;
    106       h_category = 1;
    107     } else if (b > g) {
    108       // r < g < b
    109       *v_max = b;
    110       v_mid = g;
    111       *v_min = r;
    112       h_category = 3;
    113     } else {
    114       // r < b < g
    115       *v_max = g;
    116       v_mid = b;
    117       *v_min = r;
    118       h_category = 2;
    119     }
    120   } else {
    121     // g < r
    122     if (b < g) {
    123       // b < g < r
    124       *v_max = r;
    125       v_mid = g;
    126       *v_min = b;
    127       h_category = 0;
    128     } else if (b > r) {
    129       // g < r < b
    130       *v_max = b;
    131       v_mid = r;
    132       *v_min = g;
    133       h_category = 4;
    134     } else {
    135       // g < b < r
    136       *v_max = r;
    137       v_mid = b;
    138       *v_min = g;
    139       h_category = 5;
    140     }
    141   }
    142   if (*v_max == *v_min) {
    143     *h = 0;
    144     return;
    145   }
    146   auto ratio = (v_mid - *v_min) / (*v_max - *v_min);
    147   bool increase = ((h_category & 0x1) == 0);
    148   *h = h_category + (increase ? ratio : (1 - ratio));
    149 }
    150 
    151 // Helper function to convert from H-and-V-range to RGB.
    152 static void hv_range_to_rgb(float h, float v_min, float v_max, float* r,
    153                             float* g, float* b) {
    154   int h_category = static_cast<int>(h);
    155   float ratio = h - h_category;
    156   bool increase = ((h_category & 0x1) == 0);
    157   if (!increase) {
    158     ratio = 1 - ratio;
    159   }
    160   float v_mid = v_min + ratio * (v_max - v_min);
    161   // According to the figures in:
    162   // https://en.wikipedia.org/wiki/HSL_and_HSV#Hue_and_chroma
    163   switch (h_category) {
    164     case 0:
    165       *r = v_max;
    166       *g = v_mid;
    167       *b = v_min;
    168       break;
    169     case 1:
    170       *r = v_mid;
    171       *g = v_max;
    172       *b = v_min;
    173       break;
    174     case 2:
    175       *r = v_min;
    176       *g = v_max;
    177       *b = v_mid;
    178       break;
    179     case 3:
    180       *r = v_min;
    181       *g = v_mid;
    182       *b = v_max;
    183       break;
    184     case 4:
    185       *r = v_mid;
    186       *g = v_min;
    187       *b = v_max;
    188       break;
    189     case 5:
    190     default:
    191       *r = v_max;
    192       *g = v_min;
    193       *b = v_mid;
    194   }
    195 }
    196 }  // namespace internal
    197 
    198 template <>
    199 class AdjustHueOp<CPUDevice> : public AdjustHueOpBase {
    200  public:
    201   explicit AdjustHueOp(OpKernelConstruction* context)
    202       : AdjustHueOpBase(context) {}
    203 
    204   void DoCompute(OpKernelContext* context,
    205                  const ComputeOptions& options) override {
    206     const Tensor* input = options.input;
    207     const Tensor* delta = options.delta;
    208     Tensor* output = options.output;
    209     const int64 channel_count = options.channel_count;
    210     static const int kChannelSize = 3;
    211     auto input_data = input->shaped<float, 2>({channel_count, kChannelSize});
    212     const float delta_h = delta->scalar<float>()();
    213     auto output_data = output->shaped<float, 2>({channel_count, kChannelSize});
    214     const int kCostPerChannel = 10;
    215     const DeviceBase::CpuWorkerThreads& worker_threads =
    216         *context->device()->tensorflow_cpu_worker_threads();
    217     Shard(worker_threads.num_threads, worker_threads.workers, channel_count,
    218           kCostPerChannel,
    219           [channel_count, &input_data, &output_data, delta_h](
    220               int64 start_channel, int64 end_channel) {
    221             const float* p = input_data.data() + start_channel * kChannelSize;
    222             float* q = output_data.data() + start_channel * kChannelSize;
    223             for (int i = start_channel; i < end_channel; i++) {
    224               float h, v_min, v_max;
    225               // Convert the RGB color to Hue/V-range.
    226               internal::rgb_to_hv_range(p[0], p[1], p[2], &h, &v_min, &v_max);
    227               static const int kChannelRange = 6;
    228               // Adjust the hue value. And adjust the hue back into the valid
    229               // range of [0, 6). It is faster than a fmod by avoiding
    230               // a float-point division since h is often very close to this
    231               // range.
    232               h += delta_h * kChannelRange;
    233               while (h < 0) {
    234                 h += kChannelRange;
    235               }
    236               while (h >= kChannelRange) {
    237                 h -= kChannelRange;
    238               }
    239               // Convert the hue and v-range back into RGB.
    240               internal::hv_range_to_rgb(h, v_min, v_max, q, q + 1, q + 2);
    241               p += kChannelSize;
    242               q += kChannelSize;
    243             }
    244           });
    245   }
    246 };
    247 
    248 REGISTER_KERNEL_BUILDER(Name("AdjustHue").Device(DEVICE_CPU),
    249                         AdjustHueOp<CPUDevice>);
    250 
    251 #if GOOGLE_CUDA
    252 template <>
    253 class AdjustHueOp<GPUDevice> : public AdjustHueOpBase {
    254  public:
    255   explicit AdjustHueOp(OpKernelConstruction* context)
    256       : AdjustHueOpBase(context) {}
    257 
    258   void DoCompute(OpKernelContext* context,
    259                  const ComputeOptions& options) override {
    260     const Tensor* input = options.input;
    261     const Tensor* delta = options.delta;
    262     Tensor* output = options.output;
    263     const int64 number_of_elements = input->NumElements();
    264     GPUDevice device = context->eigen_gpu_device();
    265     const auto stream = device.stream();
    266     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
    267     if (number_of_elements > 0) {
    268       const float* input_data = input->flat<float>().data();
    269       const float* delta_h = delta->flat<float>().data();
    270       float* const output_data = output->flat<float>().data();
    271       functor::AdjustHueGPU()(&device, number_of_elements, input_data, delta_h,
    272                               output_data);
    273     }
    274   }
    275 };
    276 
    277 REGISTER_KERNEL_BUILDER(Name("AdjustHue").Device(DEVICE_GPU),
    278                         AdjustHueOp<GPUDevice>);
    279 
    280 #endif
    281 
    282 //} // namespace functor
    283 }  // namespace tensorflow
    284