1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/tf2xla/lib/util.h" 17 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 18 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 19 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 20 21 namespace tensorflow { 22 namespace { 23 24 // Converts 'input' from RGB format to HSV format. 25 // 'shape' is the shape of the red/green/blue tensors. 26 std::array<xla::ComputationDataHandle, 3> RGBToHSV( 27 XlaOpKernelContext* ctx, xla::ComputationBuilder* b, 28 const std::array<xla::ComputationDataHandle, 3>& rgb, DataType dtype, 29 const TensorShape& shape) { 30 auto zero = XlaHelpers::Zero(b, dtype); 31 auto one = XlaHelpers::One(b, dtype); 32 33 auto red = rgb[0]; 34 auto green = rgb[1]; 35 auto blue = rgb[2]; 36 auto value = b->Max(b->Max(red, green), blue); 37 auto minimum = b->Min(b->Min(red, green), blue); 38 auto range = b->Sub(value, minimum); 39 40 auto zeros = b->Broadcast(zero, shape.dim_sizes()); 41 auto saturation = b->Select(b->Gt(value, zero), b->Div(range, value), zeros); 42 43 auto norm = b->Div(XlaHelpers::FloatLiteral(b, dtype, 1.0 / 6.0), range); 44 45 auto hue = b->Select(b->Eq(green, value), 46 b->Add(b->Mul(norm, b->Sub(blue, red)), 47 XlaHelpers::FloatLiteral(b, dtype, 2.0 / 6.0)), 48 b->Add(b->Mul(norm, b->Sub(red, green)), 49 XlaHelpers::FloatLiteral(b, dtype, 4.0 / 6.0))); 50 hue = b->Select(b->Eq(red, value), b->Mul(norm, b->Sub(green, blue)), hue); 51 hue = b->Select(b->Gt(range, zero), hue, zeros); 52 hue = b->Select(b->Lt(hue, zero), b->Add(hue, one), hue); 53 return {hue, saturation, value}; 54 } 55 56 // Converts 'input' from HSV format to RGB format. 57 std::array<xla::ComputationDataHandle, 3> HSVToRGB( 58 xla::ComputationBuilder* b, 59 const std::array<xla::ComputationDataHandle, 3>& hsv, DataType dtype) { 60 xla::ComputationDataHandle hue = hsv[0]; 61 xla::ComputationDataHandle saturation = hsv[1]; 62 xla::ComputationDataHandle value = hsv[2]; 63 auto zero = XlaHelpers::Zero(b, dtype); 64 auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0); 65 auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0); 66 auto three = XlaHelpers::FloatLiteral(b, dtype, 3.0); 67 auto four = XlaHelpers::FloatLiteral(b, dtype, 4.0); 68 auto six = XlaHelpers::FloatLiteral(b, dtype, 6.0); 69 70 auto dh = b->Mul(hue, six); 71 auto dr = b->Clamp(zero, b->Sub(b->Abs(b->Sub(dh, three)), one), one); 72 auto dg = b->Clamp(zero, b->Sub(two, b->Abs(b->Sub(dh, two))), one); 73 auto db = b->Clamp(zero, b->Sub(two, b->Abs(b->Sub(dh, four))), one); 74 auto one_minus_s = b->Sub(one, saturation); 75 76 auto red = b->Mul(b->Add(one_minus_s, b->Mul(saturation, dr)), value); 77 auto green = b->Mul(b->Add(one_minus_s, b->Mul(saturation, dg)), value); 78 auto blue = b->Mul(b->Add(one_minus_s, b->Mul(saturation, db)), value); 79 return {red, green, blue}; 80 } 81 82 class RGBToHSVOp : public XlaOpKernel { 83 public: 84 explicit RGBToHSVOp(OpKernelConstruction* context) : XlaOpKernel(context) {} 85 86 void Compile(XlaOpKernelContext* context) override { 87 const TensorShape input_shape = context->InputShape(0); 88 OP_REQUIRES(context, input_shape.dims() >= 1, 89 errors::InvalidArgument("input must be at least 1D", 90 input_shape.DebugString())); 91 int channel_dim = input_shape.dims() - 1; 92 int64 channels = input_shape.dim_size(channel_dim); 93 OP_REQUIRES( 94 context, channels == 3, 95 errors::FailedPrecondition("input must have 3 channels but input has ", 96 channels, " channels.")); 97 98 xla::ComputationBuilder* b = context->builder(); 99 xla::ComputationDataHandle input = context->Input(0); 100 101 xla::ComputationDataHandle red = 102 b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, 103 /*dimno=*/channel_dim); 104 xla::ComputationDataHandle green = 105 b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, 106 /*dimno=*/channel_dim); 107 xla::ComputationDataHandle blue = 108 b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, 109 /*dimno=*/channel_dim); 110 TensorShape channel_shape = input_shape; 111 channel_shape.set_dim(channel_dim, 1); 112 auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0), 113 channel_shape); 114 115 context->SetOutput(0, b->ConcatInDim(hsv, channel_dim)); 116 } 117 }; 118 REGISTER_XLA_OP(Name("RGBToHSV"), RGBToHSVOp); 119 120 class HSVToRGBOp : public XlaOpKernel { 121 public: 122 explicit HSVToRGBOp(OpKernelConstruction* context) : XlaOpKernel(context) {} 123 124 void Compile(XlaOpKernelContext* context) override { 125 const TensorShape input_shape = context->InputShape(0); 126 OP_REQUIRES(context, input_shape.dims() >= 1, 127 errors::InvalidArgument("input must be at least 1D", 128 input_shape.DebugString())); 129 int channel_dim = input_shape.dims() - 1; 130 int64 channels = input_shape.dim_size(channel_dim); 131 OP_REQUIRES( 132 context, channels == 3, 133 errors::FailedPrecondition("input must have 3 channels but input has ", 134 channels, " channels.")); 135 136 xla::ComputationBuilder* b = context->builder(); 137 xla::ComputationDataHandle input = context->Input(0); 138 xla::ComputationDataHandle hue = 139 b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, 140 /*dimno=*/channel_dim); 141 xla::ComputationDataHandle saturation = 142 b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, 143 /*dimno=*/channel_dim); 144 xla::ComputationDataHandle value = 145 b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, 146 /*dimno=*/channel_dim); 147 148 auto rgb = HSVToRGB(context->builder(), {hue, saturation, value}, 149 context->input_type(0)); 150 151 context->SetOutput(0, b->ConcatInDim(rgb, channel_dim)); 152 } 153 }; 154 REGISTER_XLA_OP(Name("HSVToRGB"), HSVToRGBOp); 155 156 class AdjustContrastOpV2 : public XlaOpKernel { 157 public: 158 explicit AdjustContrastOpV2(OpKernelConstruction* context) 159 : XlaOpKernel(context) {} 160 161 void Compile(XlaOpKernelContext* context) override { 162 const TensorShape& input_shape = context->InputShape(0); 163 const TensorShape& factor_shape = context->InputShape(1); 164 OP_REQUIRES(context, input_shape.dims() >= 3, 165 errors::InvalidArgument("input must be at least 3-D, got shape", 166 input_shape.DebugString())); 167 int height_dim = input_shape.dims() - 3; 168 int width_dim = input_shape.dims() - 2; 169 int channel_dim = input_shape.dims() - 1; 170 const int64 height = input_shape.dim_size(height_dim); 171 const int64 width = input_shape.dim_size(width_dim); 172 173 OP_REQUIRES(context, TensorShapeUtils::IsScalar(factor_shape), 174 errors::InvalidArgument("contrast_factor must be scalar: ", 175 factor_shape.DebugString())); 176 177 xla::ComputationBuilder* b = context->builder(); 178 xla::ComputationDataHandle input = context->Input(0); 179 xla::ComputationDataHandle factor = context->Input(1); 180 181 DataType type = context->input_type(0); 182 183 auto output = b->Reduce(input, /*init_value=*/XlaHelpers::Zero(b, type), 184 /*computation=*/*context->GetOrCreateAdd(type), 185 {height_dim, width_dim}); 186 output = b->Div(output, XlaHelpers::FloatLiteral(b, type, height * width)); 187 188 std::vector<int64> broadcast_dims(input_shape.dims() - 2); 189 std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); 190 broadcast_dims.back() = channel_dim; 191 output = b->Add(b->Mul(input, factor), 192 b->Mul(output, b->Sub(XlaHelpers::One(b, type), factor)), 193 broadcast_dims); 194 context->SetOutput(0, output); 195 } 196 }; 197 REGISTER_XLA_OP(Name("AdjustContrastv2"), AdjustContrastOpV2); 198 199 class AdjustSaturationOp : public XlaOpKernel { 200 public: 201 explicit AdjustSaturationOp(OpKernelConstruction* context) 202 : XlaOpKernel(context) {} 203 204 void Compile(XlaOpKernelContext* context) override { 205 const TensorShape& input_shape = context->InputShape(0); 206 const TensorShape& scale_shape = context->InputShape(1); 207 OP_REQUIRES(context, input_shape.dims() >= 3, 208 errors::InvalidArgument("input must be at least 3-D, got shape", 209 input_shape.DebugString())); 210 OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_shape), 211 errors::InvalidArgument("scale must be scalar: ", 212 scale_shape.DebugString())); 213 const int channel_dim = input_shape.dims() - 1; 214 const int64 channels = input_shape.dim_size(channel_dim); 215 OP_REQUIRES( 216 context, channels == 3, 217 errors::InvalidArgument("input must have 3 channels but instead has ", 218 channels, " channels.")); 219 220 xla::ComputationBuilder* b = context->builder(); 221 xla::ComputationDataHandle input = context->Input(0); 222 xla::ComputationDataHandle scale = context->Input(1); 223 224 DataType type = context->input_type(0); 225 226 xla::ComputationDataHandle red = 227 b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, 228 /*dimno=*/channel_dim); 229 xla::ComputationDataHandle green = 230 b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, 231 /*dimno=*/channel_dim); 232 xla::ComputationDataHandle blue = 233 b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, 234 /*dimno=*/channel_dim); 235 TensorShape channel_shape = input_shape; 236 channel_shape.set_dim(channel_dim, 1); 237 auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0), 238 channel_shape); 239 240 hsv[1] = b->Clamp(XlaHelpers::Zero(b, type), b->Mul(hsv[1], scale), 241 XlaHelpers::One(b, type)); 242 243 auto rgb = HSVToRGB(context->builder(), hsv, context->input_type(0)); 244 245 context->SetOutput(0, b->ConcatInDim(rgb, channel_dim)); 246 } 247 }; 248 REGISTER_XLA_OP(Name("AdjustSaturation"), AdjustSaturationOp); 249 250 class AdjustHueOp : public XlaOpKernel { 251 public: 252 explicit AdjustHueOp(OpKernelConstruction* context) : XlaOpKernel(context) {} 253 254 void Compile(XlaOpKernelContext* context) override { 255 const TensorShape& input_shape = context->InputShape(0); 256 const TensorShape& delta_shape = context->InputShape(1); 257 OP_REQUIRES(context, input_shape.dims() >= 3, 258 errors::InvalidArgument("input must be at least 3-D, got shape", 259 input_shape.DebugString())); 260 OP_REQUIRES(context, TensorShapeUtils::IsScalar(delta_shape), 261 errors::InvalidArgument("delta must be scalar: ", 262 delta_shape.DebugString())); 263 const int channel_dim = input_shape.dims() - 1; 264 const int64 channels = input_shape.dim_size(channel_dim); 265 OP_REQUIRES( 266 context, channels == 3, 267 errors::InvalidArgument("input must have 3 channels but instead has ", 268 channels, " channels.")); 269 270 xla::ComputationBuilder* b = context->builder(); 271 xla::ComputationDataHandle input = context->Input(0); 272 xla::ComputationDataHandle delta = context->Input(1); 273 274 DataType type = context->input_type(0); 275 276 xla::ComputationDataHandle red = 277 b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, 278 /*dimno=*/channel_dim); 279 xla::ComputationDataHandle green = 280 b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, 281 /*dimno=*/channel_dim); 282 xla::ComputationDataHandle blue = 283 b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, 284 /*dimno=*/channel_dim); 285 TensorShape channel_shape = input_shape; 286 channel_shape.set_dim(channel_dim, 1); 287 auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0), 288 channel_shape); 289 290 auto zero = XlaHelpers::Zero(b, type); 291 auto one = XlaHelpers::One(b, type); 292 293 auto& hue = hsv[0]; 294 hue = b->Rem(b->Add(hsv[0], delta), one); 295 hue = b->Select(b->Lt(hue, zero), b->Rem(b->Add(one, hue), one), hue); 296 297 auto rgb = HSVToRGB(context->builder(), hsv, context->input_type(0)); 298 299 context->SetOutput(0, b->ConcatInDim(rgb, channel_dim)); 300 } 301 }; 302 REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp); 303 304 } // namespace 305 } // namespace tensorflow 306