Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/array_ops.cc.
     17 #define EIGEN_USE_THREADS
     18 
     19 #include <algorithm>
     20 #include <cmath>
     21 
     22 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     23 #include "tensorflow/core/framework/op_kernel.h"
     24 #include "tensorflow/core/framework/register_types.h"
     25 #include "tensorflow/core/framework/tensor.h"
     26 #include "tensorflow/core/framework/tensor_shape.h"
     27 #include "tensorflow/core/framework/tensor_types.h"
     28 #include "tensorflow/core/framework/types.h"
     29 #include "tensorflow/core/kernels/colorspace_op.h"
     30 #include "tensorflow/core/lib/core/status.h"
     31 #include "tensorflow/core/platform/logging.h"
     32 #include "tensorflow/core/platform/types.h"
     33 
     34 namespace tensorflow {
     35 
     36 typedef Eigen::ThreadPoolDevice CPUDevice;
     37 typedef Eigen::GpuDevice GPUDevice;
     38 #ifdef TENSORFLOW_USE_SYCL
     39 typedef Eigen::SyclDevice SYCLDevice;
     40 #endif
     41 
     42 template <typename Device, typename T>
     43 class RGBToHSVOp : public OpKernel {
     44  public:
     45   explicit RGBToHSVOp(OpKernelConstruction* context) : OpKernel(context) {}
     46 
     47   void Compute(OpKernelContext* context) override {
     48     const Tensor& input = context->input(0);
     49     OP_REQUIRES(context, input.dims() >= 1,
     50                 errors::InvalidArgument("input must be at least 1D",
     51                                         input.shape().DebugString()));
     52     auto channels = input.dim_size(input.dims() - 1);
     53     OP_REQUIRES(context, channels == 3,
     54                 errors::FailedPrecondition(
     55                     "input must have 3 channels but input only has ", channels,
     56                     " channels."));
     57 
     58     // Create the output Tensor with the same dimensions as the input Tensor.
     59     Tensor* output = nullptr;
     60     OP_REQUIRES_OK(context,
     61                    context->allocate_output(0, input.shape(), &output));
     62 
     63     // Make a canonical image, maintaining the last (channel) dimension, while
     64     // flattening all others do give the functor easy to work with data.
     65     typename TTypes<T, 2>::ConstTensor input_data = input.flat_inner_dims<T>();
     66     typename TTypes<T, 2>::Tensor output_data = output->flat_inner_dims<T>();
     67 
     68     Tensor trange;
     69     OP_REQUIRES_OK(
     70         context, context->allocate_temp(DataTypeToEnum<T>::value,
     71                                         TensorShape({input_data.dimension(0)}),
     72                                         &trange));
     73 
     74     typename TTypes<T, 1>::Tensor range(trange.tensor<T, 1>());
     75 
     76     functor::RGBToHSV<Device, T>()(context->eigen_device<Device>(), input_data,
     77                                    range, output_data);
     78   }
     79 };
     80 
     81 template <typename Device, typename T>
     82 class HSVToRGBOp : public OpKernel {
     83  public:
     84   explicit HSVToRGBOp(OpKernelConstruction* context) : OpKernel(context) {}
     85 
     86   void Compute(OpKernelContext* context) override {
     87     const Tensor& input = context->input(0);
     88     OP_REQUIRES(context, input.dims() >= 1,
     89                 errors::InvalidArgument("input must be at least 1D",
     90                                         input.shape().DebugString()));
     91     auto channels = input.dim_size(input.dims() - 1);
     92     OP_REQUIRES(context, channels == 3,
     93                 errors::FailedPrecondition(
     94                     "input must have 3 channels but input only has ", channels,
     95                     " channels."));
     96 
     97     // Create the output Tensor with the same dimensions as the input Tensor.
     98     Tensor* output = nullptr;
     99     OP_REQUIRES_OK(context,
    100                    context->allocate_output(0, input.shape(), &output));
    101 
    102     typename TTypes<T, 2>::ConstTensor input_data = input.flat_inner_dims<T>();
    103     typename TTypes<T, 2>::Tensor output_data = output->flat_inner_dims<T>();
    104 
    105     functor::HSVToRGB<Device, T>()(context->eigen_device<Device>(), input_data,
    106                                    output_data);
    107   }
    108 };
    109 
    110 #define REGISTER_CPU(T)                                           \
    111   REGISTER_KERNEL_BUILDER(                                        \
    112       Name("RGBToHSV").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
    113       RGBToHSVOp<CPUDevice, T>);                                  \
    114   template class RGBToHSVOp<CPUDevice, T>;                        \
    115   REGISTER_KERNEL_BUILDER(                                        \
    116       Name("HSVToRGB").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
    117       HSVToRGBOp<CPUDevice, T>);                                  \
    118   template class HSVToRGBOp<CPUDevice, T>;
    119 TF_CALL_float(REGISTER_CPU);
    120 TF_CALL_double(REGISTER_CPU);
    121 
    #if GOOGLE_CUDA
    // The GPU functor specializations are only declared here, so that this
    // translation unit does not build them; they are compiled as part of the
    // corresponding _gpu.cu.cc file.
    namespace functor {
    #define DECLARE_GPU(T)                                               \
      template <>                                                        \
      void RGBToHSV<GPUDevice, T>::operator()(                           \
          const GPUDevice& d, TTypes<T, 2>::ConstTensor input_data,      \
          TTypes<T, 1>::Tensor range, TTypes<T, 2>::Tensor output_data); \
      template <>                                                        \
      void HSVToRGB<GPUDevice, T>::operator()(                           \
          const GPUDevice& d, TTypes<T, 2>::ConstTensor input_data,      \
          TTypes<T, 2>::Tensor output_data);                             \
      extern template struct RGBToHSV<GPUDevice, T>;                     \
      extern template struct HSVToRGB<GPUDevice, T>;

    TF_CALL_float(DECLARE_GPU);
    TF_CALL_double(DECLARE_GPU);
    }  // namespace functor

    // Registers the GPU kernels of both ops for element type T.
    #define REGISTER_GPU(T)                                           \
      REGISTER_KERNEL_BUILDER(                                        \
          Name("RGBToHSV").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
          RGBToHSVOp<GPUDevice, T>);                                  \
      REGISTER_KERNEL_BUILDER(                                        \
          Name("HSVToRGB").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
          HSVToRGBOp<GPUDevice, T>);

    TF_CALL_float(REGISTER_GPU);
    TF_CALL_double(REGISTER_GPU);
    #endif  // GOOGLE_CUDA
    150 
    #if defined(TENSORFLOW_USE_SYCL)
    // Registers the SYCL kernels of both ops for element type T.
    #define REGISTER_SYCL(T)                                           \
      REGISTER_KERNEL_BUILDER(                                         \
          Name("RGBToHSV").Device(DEVICE_SYCL).TypeConstraint<T>("T"), \
          RGBToHSVOp<SYCLDevice, T>);                                  \
      REGISTER_KERNEL_BUILDER(                                         \
          Name("HSVToRGB").Device(DEVICE_SYCL).TypeConstraint<T>("T"), \
          HSVToRGBOp<SYCLDevice, T>);

    // SYCL kernels are provided for float and double.
    TF_CALL_float(REGISTER_SYCL);
    TF_CALL_double(REGISTER_SYCL);
    #endif  // TENSORFLOW_USE_SYCL
    162 
    163 }  // namespace tensorflow
    164