1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #define EIGEN_USE_THREADS 16 17 #if GOOGLE_CUDA 18 #define EIGEN_USE_GPU 19 #endif 20 21 #include "tensorflow/core/kernels/adjust_saturation_op.h" 22 #include <memory> 23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 24 #include "tensorflow/core/framework/op_kernel.h" 25 #include "tensorflow/core/framework/register_types.h" 26 #include "tensorflow/core/framework/tensor.h" 27 #include "tensorflow/core/framework/tensor_shape.h" 28 #include "tensorflow/core/framework/types.h" 29 #include "tensorflow/core/lib/core/status.h" 30 #include "tensorflow/core/platform/logging.h" 31 #include "tensorflow/core/util/work_sharder.h" 32 33 namespace tensorflow { 34 35 typedef Eigen::ThreadPoolDevice CPUDevice; 36 typedef Eigen::GpuDevice GPUDevice; 37 38 class AdjustSaturationOpBase : public OpKernel { 39 protected: 40 explicit AdjustSaturationOpBase(OpKernelConstruction* context) 41 : OpKernel(context) {} 42 43 struct ComputeOptions { 44 const Tensor* input; 45 const Tensor* scale; 46 Tensor* output; 47 int64 channel_count; 48 }; 49 50 virtual void DoCompute(OpKernelContext* context, 51 const ComputeOptions& options) = 0; 52 53 void Compute(OpKernelContext* context) override { 54 const Tensor& input = context->input(0); 55 const Tensor& scale = context->input(1); 56 OP_REQUIRES(context, input.dims() >= 3, 57 errors::InvalidArgument("input must be at least 3-D, got shape", 58 input.shape().DebugString())); 59 OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale.shape()), 60 errors::InvalidArgument("scale must be scalar: ", 61 scale.shape().DebugString())); 62 auto channels = input.dim_size(input.dims() - 1); 63 OP_REQUIRES( 64 context, channels == 3, 65 errors::InvalidArgument("input must have 3 channels but instead has ", 66 channels, " channels.")); 67 68 Tensor* output = nullptr; 69 OP_REQUIRES_OK(context, 70 context->allocate_output(0, input.shape(), &output)); 71 72 if (input.NumElements() > 0) { 73 const int64 channel_count = input.NumElements() / channels; 74 ComputeOptions options; 75 options.input = &input; 76 options.scale = &scale; 77 options.output = output; 78 options.channel_count = channel_count; 79 DoCompute(context, options); 80 } 81 } 82 }; 83 84 template <class Device> 85 class AdjustSaturationOp; 86 87 namespace internal { 88 static void rgb_to_hsv(float r, float g, float b, float* h, float* s, 89 float* v) { 90 float vv = std::max(r, std::max(g, b)); 91 float range = vv - std::min(r, std::min(g, b)); 92 if (vv > 0) { 93 *s = range / vv; 94 } else { 95 *s = 0; 96 } 97 float norm = 1.0f / (6.0f * range); 98 float hh; 99 if (r == vv) { 100 hh = norm * (g - b); 101 } else if (g == vv) { 102 hh = norm * (b - r) + 2.0 / 6.0; 103 } else { 104 hh = norm * (r - g) + 4.0 / 6.0; 105 } 106 if (range <= 0.0) { 107 hh = 0; 108 } 109 if (hh < 0.0) { 110 hh = hh + 1; 111 } 112 *v = vv; 113 *h = hh; 114 } 115 116 // Algorithm from wikipedia, https://en.wikipedia.org/wiki/HSL_and_HSV#From_HSV 117 static void hsv_to_rgb(float h, float s, float v, float* r, float* g, 118 float* b) { 119 float c = s * v; 120 float m = v - c; 121 float dh = h * 6; 122 float rr, gg, bb; 123 int h_category = static_cast<int>(dh); 124 float fmodu = dh; 125 while (fmodu <= 0) { 126 fmodu += 2.0f; 127 } 128 while (fmodu >= 2.0f) { 129 fmodu -= 2.0f; 130 } 131 float x = c * (1 - std::abs(fmodu - 1)); 132 switch (h_category) { 133 case 0: 134 rr = c; 135 gg = x; 136 bb = 0; 137 break; 138 case 1: 139 rr = x; 140 gg = c; 141 bb = 0; 142 break; 143 case 2: 144 rr = 0; 145 gg = c; 146 bb = x; 147 break; 148 case 3: 149 rr = 0; 150 gg = x; 151 bb = c; 152 break; 153 case 4: 154 rr = x; 155 gg = 0; 156 bb = c; 157 break; 158 case 5: 159 rr = c; 160 gg = 0; 161 bb = x; 162 break; 163 default: 164 rr = 0; 165 gg = 0; 166 bb = 0; 167 } 168 *r = rr + m; 169 *g = gg + m; 170 *b = bb + m; 171 } 172 173 } // namespace internal 174 175 template <> 176 class AdjustSaturationOp<CPUDevice> : public AdjustSaturationOpBase { 177 public: 178 explicit AdjustSaturationOp(OpKernelConstruction* context) 179 : AdjustSaturationOpBase(context) {} 180 181 void DoCompute(OpKernelContext* context, 182 const ComputeOptions& options) override { 183 const Tensor* input = options.input; 184 const Tensor* scale = options.scale; 185 Tensor* output = options.output; 186 const int64 channel_count = options.channel_count; 187 static const int kChannelSize = 3; 188 auto input_data = input->shaped<float, 2>({channel_count, kChannelSize}); 189 const float scale_h = scale->scalar<float>()(); 190 auto output_data = output->shaped<float, 2>({channel_count, kChannelSize}); 191 const int kCostPerChannel = 10; 192 const DeviceBase::CpuWorkerThreads& worker_threads = 193 *context->device()->tensorflow_cpu_worker_threads(); 194 Shard(worker_threads.num_threads, worker_threads.workers, channel_count, 195 kCostPerChannel, 196 [channel_count, &input_data, &output_data, scale_h]( 197 int64 start_channel, int64 end_channel) { 198 const float* p = input_data.data() + start_channel * kChannelSize; 199 float* q = output_data.data() + start_channel * kChannelSize; 200 for (int i = start_channel; i < end_channel; i++) { 201 float h, s, v; 202 // Convert the RGB color to Hue/V-range. 203 internal::rgb_to_hsv(p[0], p[1], p[2], &h, &s, &v); 204 s = std::min(1.0f, std::max(0.0f, s * scale_h)); 205 // Convert the hue and v-range back into RGB. 206 internal::hsv_to_rgb(h, s, v, q, q + 1, q + 2); 207 p += kChannelSize; 208 q += kChannelSize; 209 } 210 }); 211 } 212 }; 213 214 REGISTER_KERNEL_BUILDER(Name("AdjustSaturation").Device(DEVICE_CPU), 215 AdjustSaturationOp<CPUDevice>); 216 217 #if GOOGLE_CUDA 218 template <> 219 class AdjustSaturationOp<GPUDevice> : public AdjustSaturationOpBase { 220 public: 221 explicit AdjustSaturationOp(OpKernelConstruction* context) 222 : AdjustSaturationOpBase(context) {} 223 224 void DoCompute(OpKernelContext* context, 225 const ComputeOptions& options) override { 226 const Tensor* input = options.input; 227 const Tensor* scale = options.scale; 228 Tensor* output = options.output; 229 const int64 number_of_elements = input->NumElements(); 230 GPUDevice device = context->eigen_gpu_device(); 231 const auto stream = device.stream(); 232 OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); 233 if (number_of_elements > 0) { 234 const float* input_data = input->flat<float>().data(); 235 const float* scale_data = scale->flat<float>().data(); 236 float* const output_data = output->flat<float>().data(); 237 functor::AdjustSaturationGPU()(&device, number_of_elements, input_data, 238 scale_data, output_data); 239 } 240 } 241 }; 242 243 REGISTER_KERNEL_BUILDER(Name("AdjustSaturation").Device(DEVICE_GPU), 244 AdjustSaturationOp<GPUDevice>); 245 246 #endif 247 248 } // namespace tensorflow 249