Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_KERNELS_ADJUST_CONTRAST_OP_H_
     17 #define TENSORFLOW_KERNELS_ADJUST_CONTRAST_OP_H_
     18 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     19 #include "tensorflow/core/framework/tensor_types.h"
     20 
     21 namespace tensorflow {
     22 namespace functor {
     23 
     24 // Functor used by AdjustContrastOp to do the computations.
     25 template <typename Device, typename T>
     26 struct AdjustContrast {
     27   void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input,
     28                   typename TTypes<float>::ConstScalar contrast_factor,
     29                   typename TTypes<float>::ConstScalar min_value,
     30                   typename TTypes<float>::ConstScalar max_value,
     31                   typename TTypes<float, 4>::Tensor mean_values,
     32                   typename TTypes<float, 4>::Tensor output) {
     33     const int batch = input.dimension(0);
     34     const int height = input.dimension(1);
     35     const int width = input.dimension(2);
     36     const int channels = input.dimension(3);
     37 
     38     Eigen::array<int, 4> scalar_broadcast;
     39     scalar_broadcast[0] = batch;
     40     scalar_broadcast[1] = height;
     41     scalar_broadcast[2] = width;
     42     scalar_broadcast[3] = channels;
     43 #if !defined(EIGEN_HAS_INDEX_LIST)
     44     Eigen::array<int, 2> reduction_axis;
     45     reduction_axis[0] = 1;
     46     reduction_axis[1] = 2;
     47     Eigen::array<int, 4> broadcast_dims;
     48     broadcast_dims[0] = 1;
     49     broadcast_dims[1] = height;
     50     broadcast_dims[2] = width;
     51     broadcast_dims[3] = 1;
     52     Eigen::Tensor<int, 4>::Dimensions reshape_dims;
     53     reshape_dims[0] = batch;
     54     reshape_dims[1] = 1;
     55     reshape_dims[2] = 1;
     56     reshape_dims[3] = channels;
     57 #else
     58     Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<2> >
     59         reduction_axis;
     60     Eigen::IndexList<Eigen::type2index<1>, int, int, Eigen::type2index<1> >
     61         broadcast_dims;
     62     broadcast_dims.set(1, height);
     63     broadcast_dims.set(2, width);
     64     Eigen::IndexList<int, Eigen::type2index<1>, Eigen::type2index<1>, int>
     65         reshape_dims;
     66     reshape_dims.set(0, batch);
     67     reshape_dims.set(3, channels);
     68 #endif
     69     Eigen::Sizes<1, 1, 1, 1> scalar;
     70     float num_reduced_coeffs = height * width;
     71     mean_values.device(d) =
     72         (input.template cast<float>().sum(reduction_axis).eval() /
     73          num_reduced_coeffs)
     74             .reshape(reshape_dims)
     75             .broadcast(broadcast_dims);
     76 
     77     auto contrast_factor_tensor =
     78         contrast_factor.reshape(scalar).broadcast(scalar_broadcast);
     79     auto adjusted =
     80         (input.template cast<float>() - mean_values) * contrast_factor_tensor +
     81         mean_values;
     82     auto min_bcast = min_value.reshape(scalar).broadcast(scalar_broadcast);
     83     auto max_bcast = max_value.reshape(scalar).broadcast(scalar_broadcast);
     84     // TODO(wicke): This is rather slow and should be re-written as pure cuda.
     85     output.device(d) = adjusted.cwiseMin(max_bcast).cwiseMax(min_bcast);
     86   }
     87 };
     88 
     89 // Functor used by AdjustContrastOpv2 to do the computations.
     90 template <typename Device>
     91 struct AdjustContrastv2 {
     92   void operator()(const Device& d, typename TTypes<float, 4>::ConstTensor input,
     93                   typename TTypes<float>::ConstScalar contrast_factor,
     94                   typename TTypes<float, 4>::Tensor output) {
     95     const int batch = input.dimension(0);
     96     const int height = input.dimension(1);
     97     const int width = input.dimension(2);
     98     const int channels = input.dimension(3);
     99 
    100     Eigen::array<int, 4> scalar_broadcast;
    101     scalar_broadcast[0] = batch;
    102     scalar_broadcast[1] = height;
    103     scalar_broadcast[2] = width;
    104     scalar_broadcast[3] = channels;
    105 #if !defined(EIGEN_HAS_INDEX_LIST)
    106     Eigen::array<int, 2> reduction_axis;
    107     reduction_axis[0] = 0;
    108     reduction_axis[1] = 1;
    109     Eigen::array<int, 4> broadcast_dims;
    110     broadcast_dims[0] = 1;
    111     broadcast_dims[1] = height;
    112     broadcast_dims[2] = width;
    113     broadcast_dims[3] = 1;
    114     Eigen::Tensor<int, 4>::Dimensions reshape_dims;
    115     reshape_dims[0] = batch;
    116     reshape_dims[1] = 1;
    117     reshape_dims[2] = 1;
    118     reshape_dims[3] = channels;
    119     Eigen::array<int, 4> reduced_dims_first;
    120     reduced_dims_first[0] = 1;
    121     reduced_dims_first[1] = 2;
    122     reduced_dims_first[2] = 0;
    123     reduced_dims_first[3] = 3;
    124 #else
    125     Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<1> >
    126         reduction_axis;
    127     Eigen::IndexList<Eigen::type2index<1>, int, int, Eigen::type2index<1> >
    128         broadcast_dims;
    129     broadcast_dims.set(1, height);
    130     broadcast_dims.set(2, width);
    131     Eigen::IndexList<int, Eigen::type2index<1>, Eigen::type2index<1>, int>
    132         reshape_dims;
    133     reshape_dims.set(0, batch);
    134     reshape_dims.set(3, channels);
    135     Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<2>,
    136                      Eigen::type2index<0>, Eigen::type2index<3> >
    137         reduced_dims_first;
    138 #endif
    139     Eigen::Sizes<1, 1, 1, 1> scalar;
    140     float num_reduced_coeffs = height * width;
    141     output.device(d) =
    142         (input.shuffle(reduced_dims_first).sum(reduction_axis).eval() /
    143          num_reduced_coeffs)
    144             .reshape(reshape_dims)
    145             .broadcast(broadcast_dims);
    146     auto contrast_factor_tensor =
    147         contrast_factor.reshape(scalar).broadcast(scalar_broadcast);
    148     auto adjusted = (input - output) * contrast_factor_tensor;
    149     output.device(d) += adjusted;
    150   }
    151 };
    152 
    153 }  // namespace functor
    154 }  // namespace tensorflow
    155 
    156 #endif  // TENSORFLOW_KERNELS_ADJUST_CONTRAST_OP_H_
    157