1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 Licensed under the Apache License, Version 2.0 (the "License"); 3 you may not use this file except in compliance with the License. 4 You may obtain a copy of the License at 5 6 http://www.apache.org/licenses/LICENSE-2.0 7 8 Unless required by applicable law or agreed to in writing, software 9 distributed under the License is distributed on an "AS IS" BASIS, 10 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 See the License for the specific language governing permissions and 12 limitations under the License. 13 ==============================================================================*/ 14 #define EIGEN_USE_THREADS 15 16 #if GOOGLE_CUDA 17 #define EIGEN_USE_GPU 18 #endif 19 20 #include <memory> 21 22 #include "tensorflow/core/kernels/adjust_hue_op.h" 23 24 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 25 #include "tensorflow/core/framework/op_kernel.h" 26 #include "tensorflow/core/framework/register_types.h" 27 #include "tensorflow/core/framework/tensor.h" 28 #include "tensorflow/core/framework/tensor_shape.h" 29 #include "tensorflow/core/framework/tensor_types.h" 30 #include "tensorflow/core/framework/types.h" 31 #include "tensorflow/core/lib/core/status.h" 32 #include "tensorflow/core/platform/logging.h" 33 #include "tensorflow/core/util/work_sharder.h" 34 35 namespace tensorflow { 36 37 typedef Eigen::ThreadPoolDevice CPUDevice; 38 typedef Eigen::GpuDevice GPUDevice; 39 40 class AdjustHueOpBase : public OpKernel { 41 protected: 42 explicit AdjustHueOpBase(OpKernelConstruction* context) : OpKernel(context) {} 43 44 struct ComputeOptions { 45 const Tensor* input; 46 const Tensor* delta; 47 Tensor* output; 48 int64 channel_count; 49 }; 50 51 virtual void DoCompute(OpKernelContext* context, 52 const ComputeOptions& options) = 0; 53 54 void Compute(OpKernelContext* context) override { 55 const Tensor& input = context->input(0); 56 const Tensor& delta = context->input(1); 57 OP_REQUIRES(context, input.dims() >= 3, 58 errors::InvalidArgument("input must be at least 3-D, got shape", 59 input.shape().DebugString())); 60 OP_REQUIRES(context, TensorShapeUtils::IsScalar(delta.shape()), 61 errors::InvalidArgument("delta must be scalar: ", 62 delta.shape().DebugString())); 63 auto channels = input.dim_size(input.dims() - 1); 64 OP_REQUIRES( 65 context, channels == 3, 66 errors::InvalidArgument("input must have 3 channels but instead has ", 67 channels, " channels.")); 68 69 Tensor* output = nullptr; 70 OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( 71 {0}, 0, input.shape(), &output)); 72 73 if (input.NumElements() > 0) { 74 const int64 channel_count = input.NumElements() / channels; 75 ComputeOptions options; 76 options.input = &input; 77 options.delta = δ 78 options.output = output; 79 options.channel_count = channel_count; 80 DoCompute(context, options); 81 } 82 } 83 }; 84 85 template <class Device> 86 class AdjustHueOp; 87 88 namespace internal { 89 90 // Helper function to convert a RGB color to H-and-V-range. H is in the range 91 // of [0, 6] instead of the normal [0, 1] 92 static void rgb_to_hv_range(float r, float g, float b, float* h, float* v_min, 93 float* v_max) { 94 float v_mid; 95 int h_category; 96 // According to the figures in: 97 // https://en.wikipedia.org/wiki/HSL_and_HSV#Hue_and_chroma 98 // For the conditions, we don't care about the case where two components are 99 // equal. It is okay to count it in either side in that case. 100 if (r < g) { 101 if (b < r) { 102 // b < r < g 103 *v_max = g; 104 v_mid = r; 105 *v_min = b; 106 h_category = 1; 107 } else if (b > g) { 108 // r < g < b 109 *v_max = b; 110 v_mid = g; 111 *v_min = r; 112 h_category = 3; 113 } else { 114 // r < b < g 115 *v_max = g; 116 v_mid = b; 117 *v_min = r; 118 h_category = 2; 119 } 120 } else { 121 // g < r 122 if (b < g) { 123 // b < g < r 124 *v_max = r; 125 v_mid = g; 126 *v_min = b; 127 h_category = 0; 128 } else if (b > r) { 129 // g < r < b 130 *v_max = b; 131 v_mid = r; 132 *v_min = g; 133 h_category = 4; 134 } else { 135 // g < b < r 136 *v_max = r; 137 v_mid = b; 138 *v_min = g; 139 h_category = 5; 140 } 141 } 142 if (*v_max == *v_min) { 143 *h = 0; 144 return; 145 } 146 auto ratio = (v_mid - *v_min) / (*v_max - *v_min); 147 bool increase = ((h_category & 0x1) == 0); 148 *h = h_category + (increase ? ratio : (1 - ratio)); 149 } 150 151 // Helper function to convert from H-and-V-range to RGB. 152 static void hv_range_to_rgb(float h, float v_min, float v_max, float* r, 153 float* g, float* b) { 154 int h_category = static_cast<int>(h); 155 float ratio = h - h_category; 156 bool increase = ((h_category & 0x1) == 0); 157 if (!increase) { 158 ratio = 1 - ratio; 159 } 160 float v_mid = v_min + ratio * (v_max - v_min); 161 // According to the figures in: 162 // https://en.wikipedia.org/wiki/HSL_and_HSV#Hue_and_chroma 163 switch (h_category) { 164 case 0: 165 *r = v_max; 166 *g = v_mid; 167 *b = v_min; 168 break; 169 case 1: 170 *r = v_mid; 171 *g = v_max; 172 *b = v_min; 173 break; 174 case 2: 175 *r = v_min; 176 *g = v_max; 177 *b = v_mid; 178 break; 179 case 3: 180 *r = v_min; 181 *g = v_mid; 182 *b = v_max; 183 break; 184 case 4: 185 *r = v_mid; 186 *g = v_min; 187 *b = v_max; 188 break; 189 case 5: 190 default: 191 *r = v_max; 192 *g = v_min; 193 *b = v_mid; 194 } 195 } 196 } // namespace internal 197 198 template <> 199 class AdjustHueOp<CPUDevice> : public AdjustHueOpBase { 200 public: 201 explicit AdjustHueOp(OpKernelConstruction* context) 202 : AdjustHueOpBase(context) {} 203 204 void DoCompute(OpKernelContext* context, 205 const ComputeOptions& options) override { 206 const Tensor* input = options.input; 207 const Tensor* delta = options.delta; 208 Tensor* output = options.output; 209 const int64 channel_count = options.channel_count; 210 static const int kChannelSize = 3; 211 auto input_data = input->shaped<float, 2>({channel_count, kChannelSize}); 212 const float delta_h = delta->scalar<float>()(); 213 auto output_data = output->shaped<float, 2>({channel_count, kChannelSize}); 214 const int kCostPerChannel = 10; 215 const DeviceBase::CpuWorkerThreads& worker_threads = 216 *context->device()->tensorflow_cpu_worker_threads(); 217 Shard(worker_threads.num_threads, worker_threads.workers, channel_count, 218 kCostPerChannel, 219 [channel_count, &input_data, &output_data, delta_h]( 220 int64 start_channel, int64 end_channel) { 221 const float* p = input_data.data() + start_channel * kChannelSize; 222 float* q = output_data.data() + start_channel * kChannelSize; 223 for (int i = start_channel; i < end_channel; i++) { 224 float h, v_min, v_max; 225 // Convert the RGB color to Hue/V-range. 226 internal::rgb_to_hv_range(p[0], p[1], p[2], &h, &v_min, &v_max); 227 static const int kChannelRange = 6; 228 // Adjust the hue value. And adjust the hue back into the valid 229 // range of [0, 6). It is faster than a fmod by avoiding 230 // a float-point division since h is often very close to this 231 // range. 232 h += delta_h * kChannelRange; 233 while (h < 0) { 234 h += kChannelRange; 235 } 236 while (h >= kChannelRange) { 237 h -= kChannelRange; 238 } 239 // Convert the hue and v-range back into RGB. 240 internal::hv_range_to_rgb(h, v_min, v_max, q, q + 1, q + 2); 241 p += kChannelSize; 242 q += kChannelSize; 243 } 244 }); 245 } 246 }; 247 248 REGISTER_KERNEL_BUILDER(Name("AdjustHue").Device(DEVICE_CPU), 249 AdjustHueOp<CPUDevice>); 250 251 #if GOOGLE_CUDA 252 template <> 253 class AdjustHueOp<GPUDevice> : public AdjustHueOpBase { 254 public: 255 explicit AdjustHueOp(OpKernelConstruction* context) 256 : AdjustHueOpBase(context) {} 257 258 void DoCompute(OpKernelContext* context, 259 const ComputeOptions& options) override { 260 const Tensor* input = options.input; 261 const Tensor* delta = options.delta; 262 Tensor* output = options.output; 263 const int64 number_of_elements = input->NumElements(); 264 GPUDevice device = context->eigen_gpu_device(); 265 const auto stream = device.stream(); 266 OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); 267 if (number_of_elements > 0) { 268 const float* input_data = input->flat<float>().data(); 269 const float* delta_h = delta->flat<float>().data(); 270 float* const output_data = output->flat<float>().data(); 271 functor::AdjustHueGPU()(&device, number_of_elements, input_data, delta_h, 272 output_data); 273 } 274 } 275 }; 276 277 REGISTER_KERNEL_BUILDER(Name("AdjustHue").Device(DEVICE_GPU), 278 AdjustHueOp<GPUDevice>); 279 280 #endif 281 282 //} // namespace functor 283 } // namespace tensorflow 284