Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #define EIGEN_USE_THREADS
     16 
     17 #if GOOGLE_CUDA
     18 #define EIGEN_USE_GPU
     19 #endif
     20 
#include "tensorflow/core/kernels/adjust_saturation_op.h"

#include <algorithm>
#include <cmath>
#include <memory>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/work_sharder.h"
     32 
     33 namespace tensorflow {
     34 
     35 typedef Eigen::ThreadPoolDevice CPUDevice;
     36 typedef Eigen::GpuDevice GPUDevice;
     37 
     38 class AdjustSaturationOpBase : public OpKernel {
     39  protected:
     40   explicit AdjustSaturationOpBase(OpKernelConstruction* context)
     41       : OpKernel(context) {}
     42 
     43   struct ComputeOptions {
     44     const Tensor* input;
     45     const Tensor* scale;
     46     Tensor* output;
     47     int64 channel_count;
     48   };
     49 
     50   virtual void DoCompute(OpKernelContext* context,
     51                          const ComputeOptions& options) = 0;
     52 
     53   void Compute(OpKernelContext* context) override {
     54     const Tensor& input = context->input(0);
     55     const Tensor& scale = context->input(1);
     56     OP_REQUIRES(context, input.dims() >= 3,
     57                 errors::InvalidArgument("input must be at least 3-D, got shape",
     58                                         input.shape().DebugString()));
     59     OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale.shape()),
     60                 errors::InvalidArgument("scale must be scalar: ",
     61                                         scale.shape().DebugString()));
     62     auto channels = input.dim_size(input.dims() - 1);
     63     OP_REQUIRES(
     64         context, channels == 3,
     65         errors::InvalidArgument("input must have 3 channels but instead has ",
     66                                 channels, " channels."));
     67 
     68     Tensor* output = nullptr;
     69     OP_REQUIRES_OK(context,
     70                    context->allocate_output(0, input.shape(), &output));
     71 
     72     if (input.NumElements() > 0) {
     73       const int64 channel_count = input.NumElements() / channels;
     74       ComputeOptions options;
     75       options.input = &input;
     76       options.scale = &scale;
     77       options.output = output;
     78       options.channel_count = channel_count;
     79       DoCompute(context, options);
     80     }
     81   }
     82 };
     83 
// Primary template, declared but not defined. Explicit specializations for
// CPUDevice and (when GOOGLE_CUDA is set) GPUDevice follow below.
template <class Device>
class AdjustSaturationOp;
     86 
     87 namespace internal {
     88 static void rgb_to_hsv(float r, float g, float b, float* h, float* s,
     89                        float* v) {
     90   float vv = std::max(r, std::max(g, b));
     91   float range = vv - std::min(r, std::min(g, b));
     92   if (vv > 0) {
     93     *s = range / vv;
     94   } else {
     95     *s = 0;
     96   }
     97   float norm = 1.0f / (6.0f * range);
     98   float hh;
     99   if (r == vv) {
    100     hh = norm * (g - b);
    101   } else if (g == vv) {
    102     hh = norm * (b - r) + 2.0 / 6.0;
    103   } else {
    104     hh = norm * (r - g) + 4.0 / 6.0;
    105   }
    106   if (range <= 0.0) {
    107     hh = 0;
    108   }
    109   if (hh < 0.0) {
    110     hh = hh + 1;
    111   }
    112   *v = vv;
    113   *h = hh;
    114 }
    115 
    116 // Algorithm from wikipedia, https://en.wikipedia.org/wiki/HSL_and_HSV#From_HSV
    117 static void hsv_to_rgb(float h, float s, float v, float* r, float* g,
    118                        float* b) {
    119   float c = s * v;
    120   float m = v - c;
    121   float dh = h * 6;
    122   float rr, gg, bb;
    123   int h_category = static_cast<int>(dh);
    124   float fmodu = dh;
    125   while (fmodu <= 0) {
    126     fmodu += 2.0f;
    127   }
    128   while (fmodu >= 2.0f) {
    129     fmodu -= 2.0f;
    130   }
    131   float x = c * (1 - std::abs(fmodu - 1));
    132   switch (h_category) {
    133     case 0:
    134       rr = c;
    135       gg = x;
    136       bb = 0;
    137       break;
    138     case 1:
    139       rr = x;
    140       gg = c;
    141       bb = 0;
    142       break;
    143     case 2:
    144       rr = 0;
    145       gg = c;
    146       bb = x;
    147       break;
    148     case 3:
    149       rr = 0;
    150       gg = x;
    151       bb = c;
    152       break;
    153     case 4:
    154       rr = x;
    155       gg = 0;
    156       bb = c;
    157       break;
    158     case 5:
    159       rr = c;
    160       gg = 0;
    161       bb = x;
    162       break;
    163     default:
    164       rr = 0;
    165       gg = 0;
    166       bb = 0;
    167   }
    168   *r = rr + m;
    169   *g = gg + m;
    170   *b = bb + m;
    171 }
    172 
    173 }  // namespace internal
    174 
    175 template <>
    176 class AdjustSaturationOp<CPUDevice> : public AdjustSaturationOpBase {
    177  public:
    178   explicit AdjustSaturationOp(OpKernelConstruction* context)
    179       : AdjustSaturationOpBase(context) {}
    180 
    181   void DoCompute(OpKernelContext* context,
    182                  const ComputeOptions& options) override {
    183     const Tensor* input = options.input;
    184     const Tensor* scale = options.scale;
    185     Tensor* output = options.output;
    186     const int64 channel_count = options.channel_count;
    187     static const int kChannelSize = 3;
    188     auto input_data = input->shaped<float, 2>({channel_count, kChannelSize});
    189     const float scale_h = scale->scalar<float>()();
    190     auto output_data = output->shaped<float, 2>({channel_count, kChannelSize});
    191     const int kCostPerChannel = 10;
    192     const DeviceBase::CpuWorkerThreads& worker_threads =
    193         *context->device()->tensorflow_cpu_worker_threads();
    194     Shard(worker_threads.num_threads, worker_threads.workers, channel_count,
    195           kCostPerChannel,
    196           [channel_count, &input_data, &output_data, scale_h](
    197               int64 start_channel, int64 end_channel) {
    198             const float* p = input_data.data() + start_channel * kChannelSize;
    199             float* q = output_data.data() + start_channel * kChannelSize;
    200             for (int i = start_channel; i < end_channel; i++) {
    201               float h, s, v;
    202               // Convert the RGB color to Hue/V-range.
    203               internal::rgb_to_hsv(p[0], p[1], p[2], &h, &s, &v);
    204               s = std::min(1.0f, std::max(0.0f, s * scale_h));
    205               // Convert the hue and v-range back into RGB.
    206               internal::hsv_to_rgb(h, s, v, q, q + 1, q + 2);
    207               p += kChannelSize;
    208               q += kChannelSize;
    209             }
    210           });
    211   }
    212 };
    213 
// Registers the CPUDevice specialization as the "AdjustSaturation" kernel for
// DEVICE_CPU.
REGISTER_KERNEL_BUILDER(Name("AdjustSaturation").Device(DEVICE_CPU),
                        AdjustSaturationOp<CPUDevice>);
    216 
    217 #if GOOGLE_CUDA
    218 template <>
    219 class AdjustSaturationOp<GPUDevice> : public AdjustSaturationOpBase {
    220  public:
    221   explicit AdjustSaturationOp(OpKernelConstruction* context)
    222       : AdjustSaturationOpBase(context) {}
    223 
    224   void DoCompute(OpKernelContext* context,
    225                  const ComputeOptions& options) override {
    226     const Tensor* input = options.input;
    227     const Tensor* scale = options.scale;
    228     Tensor* output = options.output;
    229     const int64 number_of_elements = input->NumElements();
    230     GPUDevice device = context->eigen_gpu_device();
    231     const auto stream = device.stream();
    232     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
    233     if (number_of_elements > 0) {
    234       const float* input_data = input->flat<float>().data();
    235       const float* scale_data = scale->flat<float>().data();
    236       float* const output_data = output->flat<float>().data();
    237       functor::AdjustSaturationGPU()(&device, number_of_elements, input_data,
    238                                      scale_data, output_data);
    239     }
    240   }
    241 };
    242 
// Registers the GPUDevice specialization as the "AdjustSaturation" kernel for
// DEVICE_GPU.
REGISTER_KERNEL_BUILDER(Name("AdjustSaturation").Device(DEVICE_GPU),
                        AdjustSaturationOp<GPUDevice>);
    245 
    246 #endif
    247 
    248 }  // namespace tensorflow
    249