1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/array_ops.cc. 17 #define EIGEN_USE_THREADS 18 19 #include <algorithm> 20 #include <cmath> 21 22 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 23 #include "tensorflow/core/framework/op_kernel.h" 24 #include "tensorflow/core/framework/register_types.h" 25 #include "tensorflow/core/framework/tensor.h" 26 #include "tensorflow/core/framework/tensor_shape.h" 27 #include "tensorflow/core/framework/tensor_types.h" 28 #include "tensorflow/core/framework/types.h" 29 #include "tensorflow/core/kernels/colorspace_op.h" 30 #include "tensorflow/core/lib/core/status.h" 31 #include "tensorflow/core/platform/logging.h" 32 #include "tensorflow/core/platform/types.h" 33 34 namespace tensorflow { 35 36 typedef Eigen::ThreadPoolDevice CPUDevice; 37 typedef Eigen::GpuDevice GPUDevice; 38 #ifdef TENSORFLOW_USE_SYCL 39 typedef Eigen::SyclDevice SYCLDevice; 40 #endif 41 42 template <typename Device, typename T> 43 class RGBToHSVOp : public OpKernel { 44 public: 45 explicit RGBToHSVOp(OpKernelConstruction* context) : OpKernel(context) {} 46 47 void Compute(OpKernelContext* context) override { 48 const Tensor& input = context->input(0); 49 OP_REQUIRES(context, input.dims() >= 1, 50 errors::InvalidArgument("input must be at least 1D", 51 input.shape().DebugString())); 52 auto channels = input.dim_size(input.dims() - 1); 53 OP_REQUIRES(context, channels == 3, 54 errors::FailedPrecondition( 55 "input must have 3 channels but input only has ", channels, 56 " channels.")); 57 58 // Create the output Tensor with the same dimensions as the input Tensor. 59 Tensor* output = nullptr; 60 OP_REQUIRES_OK(context, 61 context->allocate_output(0, input.shape(), &output)); 62 63 // Make a canonical image, maintaining the last (channel) dimension, while 64 // flattening all others do give the functor easy to work with data. 65 typename TTypes<T, 2>::ConstTensor input_data = input.flat_inner_dims<T>(); 66 typename TTypes<T, 2>::Tensor output_data = output->flat_inner_dims<T>(); 67 68 Tensor trange; 69 OP_REQUIRES_OK( 70 context, context->allocate_temp(DataTypeToEnum<T>::value, 71 TensorShape({input_data.dimension(0)}), 72 &trange)); 73 74 typename TTypes<T, 1>::Tensor range(trange.tensor<T, 1>()); 75 76 functor::RGBToHSV<Device, T>()(context->eigen_device<Device>(), input_data, 77 range, output_data); 78 } 79 }; 80 81 template <typename Device, typename T> 82 class HSVToRGBOp : public OpKernel { 83 public: 84 explicit HSVToRGBOp(OpKernelConstruction* context) : OpKernel(context) {} 85 86 void Compute(OpKernelContext* context) override { 87 const Tensor& input = context->input(0); 88 OP_REQUIRES(context, input.dims() >= 1, 89 errors::InvalidArgument("input must be at least 1D", 90 input.shape().DebugString())); 91 auto channels = input.dim_size(input.dims() - 1); 92 OP_REQUIRES(context, channels == 3, 93 errors::FailedPrecondition( 94 "input must have 3 channels but input only has ", channels, 95 " channels.")); 96 97 // Create the output Tensor with the same dimensions as the input Tensor. 98 Tensor* output = nullptr; 99 OP_REQUIRES_OK(context, 100 context->allocate_output(0, input.shape(), &output)); 101 102 typename TTypes<T, 2>::ConstTensor input_data = input.flat_inner_dims<T>(); 103 typename TTypes<T, 2>::Tensor output_data = output->flat_inner_dims<T>(); 104 105 functor::HSVToRGB<Device, T>()(context->eigen_device<Device>(), input_data, 106 output_data); 107 } 108 }; 109 110 #define REGISTER_CPU(T) \ 111 REGISTER_KERNEL_BUILDER( \ 112 Name("RGBToHSV").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ 113 RGBToHSVOp<CPUDevice, T>); \ 114 template class RGBToHSVOp<CPUDevice, T>; \ 115 REGISTER_KERNEL_BUILDER( \ 116 Name("HSVToRGB").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ 117 HSVToRGBOp<CPUDevice, T>); \ 118 template class HSVToRGBOp<CPUDevice, T>; 119 TF_CALL_float(REGISTER_CPU); 120 TF_CALL_double(REGISTER_CPU); 121 122 #if GOOGLE_CUDA 123 // Forward declarations of the function specializations for GPU (to prevent 124 // building the GPU versions here, they will be built compiling _gpu.cu.cc). 125 namespace functor { 126 #define DECLARE_GPU(T) \ 127 template <> \ 128 void RGBToHSV<GPUDevice, T>::operator()( \ 129 const GPUDevice& d, TTypes<T, 2>::ConstTensor input_data, \ 130 TTypes<T, 1>::Tensor range, TTypes<T, 2>::Tensor output_data); \ 131 extern template struct RGBToHSV<GPUDevice, T>; \ 132 template <> \ 133 void HSVToRGB<GPUDevice, T>::operator()( \ 134 const GPUDevice& d, TTypes<T, 2>::ConstTensor input_data, \ 135 TTypes<T, 2>::Tensor output_data); \ 136 extern template struct HSVToRGB<GPUDevice, T>; 137 TF_CALL_float(DECLARE_GPU); 138 TF_CALL_double(DECLARE_GPU); 139 } // namespace functor 140 #define REGISTER_GPU(T) \ 141 REGISTER_KERNEL_BUILDER( \ 142 Name("RGBToHSV").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ 143 RGBToHSVOp<GPUDevice, T>); \ 144 REGISTER_KERNEL_BUILDER( \ 145 Name("HSVToRGB").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ 146 HSVToRGBOp<GPUDevice, T>); 147 TF_CALL_float(REGISTER_GPU); 148 TF_CALL_double(REGISTER_GPU); 149 #endif 150 151 #ifdef TENSORFLOW_USE_SYCL 152 #define REGISTER_SYCL(T) \ 153 REGISTER_KERNEL_BUILDER( \ 154 Name("RGBToHSV").Device(DEVICE_SYCL).TypeConstraint<T>("T"), \ 155 RGBToHSVOp<SYCLDevice, T>); \ 156 REGISTER_KERNEL_BUILDER( \ 157 Name("HSVToRGB").Device(DEVICE_SYCL).TypeConstraint<T>("T"), \ 158 HSVToRGBOp<SYCLDevice, T>); 159 TF_CALL_float(REGISTER_SYCL); 160 TF_CALL_double(REGISTER_SYCL); 161 #endif 162 163 } // namespace tensorflow 164