Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/tf2xla/lib/util.h"
     17 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
     18 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
     19 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
     20 
     21 namespace tensorflow {
     22 namespace {
     23 
     24 // Converts 'input' from RGB format to HSV format.
     25 // 'shape' is the shape of the red/green/blue tensors.
     26 std::array<xla::ComputationDataHandle, 3> RGBToHSV(
     27     XlaOpKernelContext* ctx, xla::ComputationBuilder* b,
     28     const std::array<xla::ComputationDataHandle, 3>& rgb, DataType dtype,
     29     const TensorShape& shape) {
     30   auto zero = XlaHelpers::Zero(b, dtype);
     31   auto one = XlaHelpers::One(b, dtype);
     32 
     33   auto red = rgb[0];
     34   auto green = rgb[1];
     35   auto blue = rgb[2];
     36   auto value = b->Max(b->Max(red, green), blue);
     37   auto minimum = b->Min(b->Min(red, green), blue);
     38   auto range = b->Sub(value, minimum);
     39 
     40   auto zeros = b->Broadcast(zero, shape.dim_sizes());
     41   auto saturation = b->Select(b->Gt(value, zero), b->Div(range, value), zeros);
     42 
     43   auto norm = b->Div(XlaHelpers::FloatLiteral(b, dtype, 1.0 / 6.0), range);
     44 
     45   auto hue = b->Select(b->Eq(green, value),
     46                        b->Add(b->Mul(norm, b->Sub(blue, red)),
     47                               XlaHelpers::FloatLiteral(b, dtype, 2.0 / 6.0)),
     48                        b->Add(b->Mul(norm, b->Sub(red, green)),
     49                               XlaHelpers::FloatLiteral(b, dtype, 4.0 / 6.0)));
     50   hue = b->Select(b->Eq(red, value), b->Mul(norm, b->Sub(green, blue)), hue);
     51   hue = b->Select(b->Gt(range, zero), hue, zeros);
     52   hue = b->Select(b->Lt(hue, zero), b->Add(hue, one), hue);
     53   return {hue, saturation, value};
     54 }
     55 
     56 // Converts 'input' from HSV format to RGB format.
     57 std::array<xla::ComputationDataHandle, 3> HSVToRGB(
     58     xla::ComputationBuilder* b,
     59     const std::array<xla::ComputationDataHandle, 3>& hsv, DataType dtype) {
     60   xla::ComputationDataHandle hue = hsv[0];
     61   xla::ComputationDataHandle saturation = hsv[1];
     62   xla::ComputationDataHandle value = hsv[2];
     63   auto zero = XlaHelpers::Zero(b, dtype);
     64   auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0);
     65   auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
     66   auto three = XlaHelpers::FloatLiteral(b, dtype, 3.0);
     67   auto four = XlaHelpers::FloatLiteral(b, dtype, 4.0);
     68   auto six = XlaHelpers::FloatLiteral(b, dtype, 6.0);
     69 
     70   auto dh = b->Mul(hue, six);
     71   auto dr = b->Clamp(zero, b->Sub(b->Abs(b->Sub(dh, three)), one), one);
     72   auto dg = b->Clamp(zero, b->Sub(two, b->Abs(b->Sub(dh, two))), one);
     73   auto db = b->Clamp(zero, b->Sub(two, b->Abs(b->Sub(dh, four))), one);
     74   auto one_minus_s = b->Sub(one, saturation);
     75 
     76   auto red = b->Mul(b->Add(one_minus_s, b->Mul(saturation, dr)), value);
     77   auto green = b->Mul(b->Add(one_minus_s, b->Mul(saturation, dg)), value);
     78   auto blue = b->Mul(b->Add(one_minus_s, b->Mul(saturation, db)), value);
     79   return {red, green, blue};
     80 }
     81 
     82 class RGBToHSVOp : public XlaOpKernel {
     83  public:
     84   explicit RGBToHSVOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
     85 
     86   void Compile(XlaOpKernelContext* context) override {
     87     const TensorShape input_shape = context->InputShape(0);
     88     OP_REQUIRES(context, input_shape.dims() >= 1,
     89                 errors::InvalidArgument("input must be at least 1D",
     90                                         input_shape.DebugString()));
     91     int channel_dim = input_shape.dims() - 1;
     92     int64 channels = input_shape.dim_size(channel_dim);
     93     OP_REQUIRES(
     94         context, channels == 3,
     95         errors::FailedPrecondition("input must have 3 channels but input has ",
     96                                    channels, " channels."));
     97 
     98     xla::ComputationBuilder* b = context->builder();
     99     xla::ComputationDataHandle input = context->Input(0);
    100 
    101     xla::ComputationDataHandle red =
    102         b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1,
    103                       /*dimno=*/channel_dim);
    104     xla::ComputationDataHandle green =
    105         b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1,
    106                       /*dimno=*/channel_dim);
    107     xla::ComputationDataHandle blue =
    108         b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1,
    109                       /*dimno=*/channel_dim);
    110     TensorShape channel_shape = input_shape;
    111     channel_shape.set_dim(channel_dim, 1);
    112     auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0),
    113                         channel_shape);
    114 
    115     context->SetOutput(0, b->ConcatInDim(hsv, channel_dim));
    116   }
    117 };
    118 REGISTER_XLA_OP(Name("RGBToHSV"), RGBToHSVOp);
    119 
    120 class HSVToRGBOp : public XlaOpKernel {
    121  public:
    122   explicit HSVToRGBOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
    123 
    124   void Compile(XlaOpKernelContext* context) override {
    125     const TensorShape input_shape = context->InputShape(0);
    126     OP_REQUIRES(context, input_shape.dims() >= 1,
    127                 errors::InvalidArgument("input must be at least 1D",
    128                                         input_shape.DebugString()));
    129     int channel_dim = input_shape.dims() - 1;
    130     int64 channels = input_shape.dim_size(channel_dim);
    131     OP_REQUIRES(
    132         context, channels == 3,
    133         errors::FailedPrecondition("input must have 3 channels but input has ",
    134                                    channels, " channels."));
    135 
    136     xla::ComputationBuilder* b = context->builder();
    137     xla::ComputationDataHandle input = context->Input(0);
    138     xla::ComputationDataHandle hue =
    139         b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1,
    140                       /*dimno=*/channel_dim);
    141     xla::ComputationDataHandle saturation =
    142         b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1,
    143                       /*dimno=*/channel_dim);
    144     xla::ComputationDataHandle value =
    145         b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1,
    146                       /*dimno=*/channel_dim);
    147 
    148     auto rgb = HSVToRGB(context->builder(), {hue, saturation, value},
    149                         context->input_type(0));
    150 
    151     context->SetOutput(0, b->ConcatInDim(rgb, channel_dim));
    152   }
    153 };
    154 REGISTER_XLA_OP(Name("HSVToRGB"), HSVToRGBOp);
    155 
    156 class AdjustContrastOpV2 : public XlaOpKernel {
    157  public:
    158   explicit AdjustContrastOpV2(OpKernelConstruction* context)
    159       : XlaOpKernel(context) {}
    160 
    161   void Compile(XlaOpKernelContext* context) override {
    162     const TensorShape& input_shape = context->InputShape(0);
    163     const TensorShape& factor_shape = context->InputShape(1);
    164     OP_REQUIRES(context, input_shape.dims() >= 3,
    165                 errors::InvalidArgument("input must be at least 3-D, got shape",
    166                                         input_shape.DebugString()));
    167     int height_dim = input_shape.dims() - 3;
    168     int width_dim = input_shape.dims() - 2;
    169     int channel_dim = input_shape.dims() - 1;
    170     const int64 height = input_shape.dim_size(height_dim);
    171     const int64 width = input_shape.dim_size(width_dim);
    172 
    173     OP_REQUIRES(context, TensorShapeUtils::IsScalar(factor_shape),
    174                 errors::InvalidArgument("contrast_factor must be scalar: ",
    175                                         factor_shape.DebugString()));
    176 
    177     xla::ComputationBuilder* b = context->builder();
    178     xla::ComputationDataHandle input = context->Input(0);
    179     xla::ComputationDataHandle factor = context->Input(1);
    180 
    181     DataType type = context->input_type(0);
    182 
    183     auto output = b->Reduce(input, /*init_value=*/XlaHelpers::Zero(b, type),
    184                             /*computation=*/*context->GetOrCreateAdd(type),
    185                             {height_dim, width_dim});
    186     output = b->Div(output, XlaHelpers::FloatLiteral(b, type, height * width));
    187 
    188     std::vector<int64> broadcast_dims(input_shape.dims() - 2);
    189     std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
    190     broadcast_dims.back() = channel_dim;
    191     output = b->Add(b->Mul(input, factor),
    192                     b->Mul(output, b->Sub(XlaHelpers::One(b, type), factor)),
    193                     broadcast_dims);
    194     context->SetOutput(0, output);
    195   }
    196 };
    197 REGISTER_XLA_OP(Name("AdjustContrastv2"), AdjustContrastOpV2);
    198 
    199 class AdjustSaturationOp : public XlaOpKernel {
    200  public:
    201   explicit AdjustSaturationOp(OpKernelConstruction* context)
    202       : XlaOpKernel(context) {}
    203 
    204   void Compile(XlaOpKernelContext* context) override {
    205     const TensorShape& input_shape = context->InputShape(0);
    206     const TensorShape& scale_shape = context->InputShape(1);
    207     OP_REQUIRES(context, input_shape.dims() >= 3,
    208                 errors::InvalidArgument("input must be at least 3-D, got shape",
    209                                         input_shape.DebugString()));
    210     OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_shape),
    211                 errors::InvalidArgument("scale must be scalar: ",
    212                                         scale_shape.DebugString()));
    213     const int channel_dim = input_shape.dims() - 1;
    214     const int64 channels = input_shape.dim_size(channel_dim);
    215     OP_REQUIRES(
    216         context, channels == 3,
    217         errors::InvalidArgument("input must have 3 channels but instead has ",
    218                                 channels, " channels."));
    219 
    220     xla::ComputationBuilder* b = context->builder();
    221     xla::ComputationDataHandle input = context->Input(0);
    222     xla::ComputationDataHandle scale = context->Input(1);
    223 
    224     DataType type = context->input_type(0);
    225 
    226     xla::ComputationDataHandle red =
    227         b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1,
    228                       /*dimno=*/channel_dim);
    229     xla::ComputationDataHandle green =
    230         b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1,
    231                       /*dimno=*/channel_dim);
    232     xla::ComputationDataHandle blue =
    233         b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1,
    234                       /*dimno=*/channel_dim);
    235     TensorShape channel_shape = input_shape;
    236     channel_shape.set_dim(channel_dim, 1);
    237     auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0),
    238                         channel_shape);
    239 
    240     hsv[1] = b->Clamp(XlaHelpers::Zero(b, type), b->Mul(hsv[1], scale),
    241                       XlaHelpers::One(b, type));
    242 
    243     auto rgb = HSVToRGB(context->builder(), hsv, context->input_type(0));
    244 
    245     context->SetOutput(0, b->ConcatInDim(rgb, channel_dim));
    246   }
    247 };
    248 REGISTER_XLA_OP(Name("AdjustSaturation"), AdjustSaturationOp);
    249 
    250 class AdjustHueOp : public XlaOpKernel {
    251  public:
    252   explicit AdjustHueOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
    253 
    254   void Compile(XlaOpKernelContext* context) override {
    255     const TensorShape& input_shape = context->InputShape(0);
    256     const TensorShape& delta_shape = context->InputShape(1);
    257     OP_REQUIRES(context, input_shape.dims() >= 3,
    258                 errors::InvalidArgument("input must be at least 3-D, got shape",
    259                                         input_shape.DebugString()));
    260     OP_REQUIRES(context, TensorShapeUtils::IsScalar(delta_shape),
    261                 errors::InvalidArgument("delta must be scalar: ",
    262                                         delta_shape.DebugString()));
    263     const int channel_dim = input_shape.dims() - 1;
    264     const int64 channels = input_shape.dim_size(channel_dim);
    265     OP_REQUIRES(
    266         context, channels == 3,
    267         errors::InvalidArgument("input must have 3 channels but instead has ",
    268                                 channels, " channels."));
    269 
    270     xla::ComputationBuilder* b = context->builder();
    271     xla::ComputationDataHandle input = context->Input(0);
    272     xla::ComputationDataHandle delta = context->Input(1);
    273 
    274     DataType type = context->input_type(0);
    275 
    276     xla::ComputationDataHandle red =
    277         b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1,
    278                       /*dimno=*/channel_dim);
    279     xla::ComputationDataHandle green =
    280         b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1,
    281                       /*dimno=*/channel_dim);
    282     xla::ComputationDataHandle blue =
    283         b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1,
    284                       /*dimno=*/channel_dim);
    285     TensorShape channel_shape = input_shape;
    286     channel_shape.set_dim(channel_dim, 1);
    287     auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0),
    288                         channel_shape);
    289 
    290     auto zero = XlaHelpers::Zero(b, type);
    291     auto one = XlaHelpers::One(b, type);
    292 
    293     auto& hue = hsv[0];
    294     hue = b->Rem(b->Add(hsv[0], delta), one);
    295     hue = b->Select(b->Lt(hue, zero), b->Rem(b->Add(one, hue), one), hue);
    296 
    297     auto rgb = HSVToRGB(context->builder(), hsv, context->input_type(0));
    298 
    299     context->SetOutput(0, b->ConcatInDim(rgb, channel_dim));
    300   }
    301 };
    302 REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp);
    303 
    304 }  // namespace
    305 }  // namespace tensorflow
    306